Revert "Port index.Tensor to structured kernels."

This reverts commit 9fe6f1baf5d1c130e717fc32d993bb531ed05b62.

Reverted https://github.com/pytorch/pytorch/pull/69607 on behalf of https://github.com/suo because it broke master; see 9fe6f1baf5
Committed by PyTorch MergeBot on 2022-06-01 00:12:15 +00:00
parent 1705be8ff7
commit fca1f495c2
18 changed files with 99 additions and 151 deletions

View File

@@ -1 +1 @@
83b5fe8d25e4498352b59f5b2609b887ae224de9
5dae54bc53eb6c9a11eb4706fe01d1dfa557c14f

View File

@@ -5,7 +5,6 @@ namespace at {
class Tensor;
class TensorBase;
struct TensorIterator;
struct TensorIteratorBase;
}
namespace c10 {
@@ -14,7 +13,7 @@ class Scalar;
namespace at { namespace native {
using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);

View File

@@ -2,7 +2,7 @@
#include <ATen/ExpandUtils.h>
#include <ATen/native/CanUse32BitIndexMath.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <c10/util/irange.h>
namespace at { namespace native {
@@ -14,14 +14,14 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
}
static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
// If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
std::vector<Tensor> result;
for (const auto& index_opt : indices) {
for (c10::optional<Tensor> index_opt : indices) {
if (!index_opt.has_value()) {
result.emplace_back();
} else {
const auto& index = *index_opt;
Tensor index = std::move(*index_opt);
if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
if (index.scalar_type() == kByte) {
TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
@@ -48,8 +48,9 @@ static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTen
return result;
}
static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices) {
for (const auto& tensor : indices) {
static C10_UNUSED void checkIndexTensorTypes(const torch::List<c10::optional<Tensor>>& indices) {
for (c10::optional<Tensor> tensor : indices) {
if (tensor.has_value() && tensor->defined()) {
auto scalarType = tensor->scalar_type();
if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
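For reference, a minimal Python sketch of the mask expansion that expandTensors performs (the tensor values and shapes here are illustrative, not taken from the diff): a bool (or deprecated uint8) mask index is turned into the equivalent LongTensor indices.

```python
import torch

x = torch.arange(12.).reshape(3, 4)
mask = x > 5  # BoolTensor used as an index

# Mask indexing ...
masked = x[mask]
# ... is equivalent to indexing with the mask's nonzero coordinates,
# which is what expandTensors produces on the C++ side.
expanded = x[mask.nonzero(as_tuple=True)]

assert torch.equal(masked, expanded)
```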

View File

@@ -56,7 +56,6 @@
#include <ATen/NativeFunctions.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/core/IListRef.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/Copy.h>
@@ -75,13 +74,6 @@
#include <vector>
namespace at {
namespace native {
std::string shapes_as_str(TensorList tensors);
AdvancedIndex make_info(Tensor self, IOptTensorListRef orig);
} // namespace native
namespace meta {
native::SCATTER_GATHER_OP get_operator_enum(const c10::string_view reduce, bool use_new_options = false) {
@@ -335,73 +327,6 @@ TORCH_PRECOMPUTE_META_FUNC(index_reduce)
return TORCH_PRECOMPUTE_STRUCT(index_reduce)().set_dim(dim);
}
static void build_index_op(
TensorIteratorBase& iter,
const at::native::AdvancedIndex& info,
const Tensor& result) {
// 'TensorIterator' needs to own the things coming from 'info', since
// 'info' will be destroyed after the META function.
TensorIteratorConfig config;
// info.src is a restrided view of result
config.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.add_output(result)
.add_owned_input(info.src);
for (auto& index : info.indices) {
config.add_owned_input(index);
}
if (!result.defined()) {
config.declare_static_dtype_and_device(info.src.scalar_type(), info.src.device());
}
iter.build(config);
}
void check_indices_on_cpu_or_selfdevice(
const Tensor& self,
const at::MaterializedIOptTensorListRef& indices) {
auto dev = self.device();
bool indices_on_cpu_or_dev = std::all_of(
indices.begin(), indices.end(), [=](const at::OptionalTensorRef& opt) {
return opt.has_value() ? (opt->is_cpu() || opt->device() == dev) : true;
});
TORCH_CHECK(
indices_on_cpu_or_dev,
"indices should be either on ", kCPU,
" or on the same device as the indexed tensor (", dev, ")");
}
TORCH_PRECOMPUTE_META_FUNC2(index, Tensor)
(const Tensor& self, at::IOptTensorListRef indices) {
auto materialized = indices.materialize();
TORCH_CHECK_INDEX(
materialized.size() <= (size_t)self.dim(),
"too many indices for tensor of dimension ",
self.dim(), " (got ", materialized.size(), ")");
// Only allow: `dev_tensor[{cpu,dev}_tensor]`.
// See: https://github.com/pytorch/pytorch/pull/69607
check_indices_on_cpu_or_selfdevice(self, materialized);
const auto& result = maybe_get_output();
if (result.defined()) {
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
for (const at::OptionalTensorRef& index : materialized) {
if (index.has_value()) {
at::assert_no_overlap(result, *index);
}
}
}
auto info = at::native::make_info(self, indices);
build_index_op(*this, info, result);
return TORCH_PRECOMPUTE_STRUCT2(index, Tensor)()
.set_sizes(std::move(info.indexed_sizes))
.set_strides(std::move(info.indexed_strides));
}
} // namespace meta
namespace native {
@@ -438,7 +363,7 @@ static bool all_strides_match(TensorList tensors) {
return true;
}
inline std::string shapes_as_str(TensorList tensors) {
static std::string shapes_as_str(TensorList tensors) {
std::ostringstream os;
bool first = true;
for (auto& tensor : tensors) {
@@ -565,7 +490,7 @@ const Tensor& value){
return std::make_tuple(true, mask);
}
inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
static AdvancedIndex make_info(Tensor self, const torch::List<c10::optional<at::Tensor>>& orig) {
checkIndexTensorTypes(orig);
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
auto indices = expandTensors(self, orig);
@@ -574,7 +499,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
indices = expand_outplace(indices);
} catch (std::exception& e) {
TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
" with shapes ", shapes_as_str(indices));
" with shapes ", shapes_as_str(indices));
}
// add missing null Tensors so that it matches self.dim()
while (indices.size() < (size_t)self.dim()) {
@@ -614,12 +539,39 @@ static TensorIterator make_index_put_iterator(const AdvancedIndex& info, const T
return config.build();
}
TORCH_IMPL_FUNC(index_out)
(const Tensor& self,
DimVector sizes,
DimVector strides,
const Tensor& result) {
index_stub(device_type(), *this, sizes, strides);
static TensorIterator make_index_iterator(const AdvancedIndex& info) {
TensorIteratorConfig config;
config.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.declare_static_dtype_and_device(info.src.scalar_type(), info.src.device())
.add_owned_output(Tensor())
.add_input(info.src);
for (auto& index : info.indices) {
config.add_input(index);
}
return config.build();
}
static TensorIterator make_index_out_iterator(const AdvancedIndex& info, Tensor& result) {
TensorIteratorConfig config;
// info.src is a restrided view of result
config.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.add_output(result)
.add_input(info.src);
for (auto& index : info.indices) {
config.add_input(index);
}
return config.build();
}
Tensor index(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
auto info = make_info(self, indices);
auto iter = make_index_iterator(info);
index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides);
return iter.output();
}
Tensor quantized_index(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
@@ -631,9 +583,33 @@ Tensor quantized_index(const Tensor & self, const torch::List<c10::optional<Tens
// For now, this is a naive implementation which does dq -> index -> q.
// TODO(future PR): improve performance by removing the copies.
const auto& self_dq = self.dequantize();
auto result = at::index(self_dq, indices);
TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
auto info = make_info(self_dq, indices);
auto iter = make_index_iterator(info);
index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides);
at::Tensor res = iter.output();
return at::quantize_per_tensor(
result, self.q_scale(), self.q_zero_point(), self.scalar_type());
res, self.q_scale(), self.q_zero_point(), self.scalar_type());
}
Tensor& index_out(Tensor& result, const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
// NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
for (const c10::optional<Tensor>& index: indices) {
if (index.has_value()) {
at::assert_no_overlap(result, *index);
}
}
auto info = make_info(self, indices);
auto iter = make_index_out_iterator(info, result);
index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides);
return result;
}
Tensor & put_(Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) {
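The quantized_index path above is documented as a naive dq -> index -> q fallback. A rough Python sketch of that sequence follows; the scale, zero point, and shapes are made up for illustration and are not part of the diff.

```python
import torch

x = torch.quantize_per_tensor(torch.randn(4, 4), scale=0.1, zero_point=0,
                              dtype=torch.qint8)
idx = torch.tensor([0, 2])

# Dequantize, index with the ordinary kernel, then re-quantize with the
# original scale/zero_point -- mirroring the dq -> index -> q comment above.
out = torch.quantize_per_tensor(x.dequantize()[idx],
                                x.q_scale(), x.q_zero_point(), x.dtype)
```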

View File

@@ -64,7 +64,7 @@ static bool is_constant_index(int ntensor, const int64_t* strides) {
}
template <typename scalar_t, typename func_t>
void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride,
const func_t& f, bool serial_execution=false)
{
int ntensor = iter.ntensors();
@@ -102,7 +102,7 @@ void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArray
}
}
void index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) {
void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
iter.dtype(), "index_cpu", [&] {
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {

View File

@@ -11,7 +11,6 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/index_cuda_dispatch.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/masked_scatter_native.h>
#include <ATen/ops/masked_select_native.h>
@@ -40,7 +39,7 @@ static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self
// owning and expand_outplace returns a borrow, the returned borrow
// would dangle.
auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
at::cuda::index_out(
at::native::index_out(
result, *std::get<1>(mask_self_expanded),
c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));
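masked_select_out_cuda_impl above expands the mask against self and then dispatches to index_out with a single boolean index. A small Python sketch of that equivalence, with arbitrarily chosen shapes:

```python
import torch

x = torch.randn(3, 4)
mask = torch.tensor([True, False, True, False])  # broadcasts along dim 0

# masked_select over broadcastable shapes matches boolean advanced indexing
# once mask and input are expanded to a common shape.
mask_e, x_e = torch.broadcast_tensors(mask, x)
assert torch.equal(torch.masked_select(x, mask), x_e[mask_e])
```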

View File

@@ -50,7 +50,7 @@ static void launch_kernel(int64_t N, const func_t& f) {
}
template <typename func_t>
void gpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, const func_t& f) {
void gpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, const func_t& f) {
int num_indices = index_size.size();
AT_ASSERT(num_indices == index_stride.size());
AT_ASSERT(num_indices == iter.ntensors() - 2);
@@ -178,7 +178,7 @@ void index_copy_kernel_impl(
}
template <typename scalar_t>
void index_kernel_impl(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) {
void index_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* out_data, char* in_data, int64_t offset) {
*(scalar_t*)out_data = *(scalar_t*)(in_data + offset);
});
@@ -191,7 +191,7 @@ void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArra
});
}
static void index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) {
static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, iter.dtype(), "index_cuda", [&] {
using dtype = OpaqueType<sizeof(scalar_t)>;
index_kernel_impl<dtype>(iter, index_size, index_stride);

View File

@@ -230,7 +230,7 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
}
static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, const c10::List<c10::optional<at::Tensor>>& orig, bool check_range) {
checkIndexTensorTypes(orig);
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
auto indices = expandTensors(self, orig);

View File

@@ -2483,24 +2483,15 @@
- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: index.Tensor_out
variants: function, method
dispatch:
CPU, CUDA: index
QuantizedCPU: quantized_index
# NB: This function is special-cased in tools/autograd/gen_variable_type.py
# NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
# - Tensor Tensor::index(ArrayRef<TensorIndex> indices)
# - Tensor Tensor::index(std::initializer_list<TensorIndex> indices)
- func: index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
structured: True
structured_inherits: TensorIteratorBase
precomputed:
- indices -> DimVector sizes, DimVector strides
dispatch:
CPU, CUDA: index_out
- func: index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
structured: True
variants: function
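As a reminder of what the `Tensor?[] indices` schema means from Python, here is a hedged sketch calling the op directly; `None` entries stand for dimensions that are not indexed. The exact lowering of `x[:, idx]` may differ internally, but the result matches.

```python
import torch

x = torch.randn(2, 3, 4)
idx = torch.tensor([0, 2])

# Tensor?[] indices: the list may contain None for skipped dimensions.
a = torch.ops.aten.index.Tensor(x, [None, idx])
b = x[:, idx]
assert torch.equal(a, b)
```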

View File

@@ -4720,7 +4720,7 @@ TEST_F(LazyOpsTest, TestOneIndexTransfer) {
torch::Tensor result = torch::index(params, {indices});
ForEachDevice([&](const torch::Device& device) {
torch::Tensor lazy_params = CopyToDevice(params, device);
torch::Tensor lazy_result = torch::index(lazy_params, {indices.cpu()});
torch::Tensor lazy_result = torch::index(lazy_params, {indices});
AllEqual(result, lazy_result);
});
}

View File

@@ -808,6 +808,7 @@ meta_dispatch_expected_failures = {
aten.histogram.bin_ct: {f64, f32},
aten.histogram.bins_tensor: {f64, f32},
aten.im2col.default: {bf16, f16, f64, f32},
aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32},
aten.kthvalue.default: {i64, bf16, u8, f32, i8, f64, i16, i32},
aten.linalg_matrix_exp.default: {bf16, f64, f32},
aten.log_sigmoid_forward.output: {bf16, f64, f32},
@@ -888,7 +889,6 @@ meta_dispatch_expected_failures = {
# these sometimes pass and sometimes fail
meta_dispatch_skips = {
aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32}, # at::nonzero doesn't have a Meta function
aten._to_copy.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.aminmax.default: {i64, u8, b8, f32, i8, f64, i16, i32},
aten.cummax.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32},
@@ -924,6 +924,7 @@ meta_dispatch_device_expected_failures['cuda'] = {
aten.grid_sampler_3d.default: {f16}, # aten::grid_sampler_3d
aten.histc.default: {i16, i32, i64, i8}, # aten::histc
aten.histc.out: {i16, i32, i64, i8}, # aten::histc.out
aten.index.Tensor: {c32}, # aten::index.Tensor
aten.kthvalue.default: {f16}, # aten::kthvalue.values
aten.linalg_cholesky_ex.L: {f32, f64}, # aten::linalg_cholesky_ex.L
aten.linalg_cholesky_ex.default: {f32, f64}, # aten::linalg_cholesky_ex

View File

@@ -5529,9 +5529,11 @@ class TestDevicePrecision(TestCase):
cpu = torch.device('cpu')
for device in devices:
# Index cpu tensor with device tensor
x = torch.randn(3, 4, 4, 4, 3)
ia = torch.tensor([0, 2, 1])
ib = torch.tensor([0, 2, 1])
ia = torch.tensor([0, 2, 1]).to(device)
ib = torch.tensor([0, 2, 1]).to(device)
test(x, ia, ib)
# Index device tensor with cpu tensor
x = x.to(device)
@@ -5539,38 +5541,22 @@ class TestDevicePrecision(TestCase):
ib = ib.to(cpu)
test(x, ia, ib)
# Index device tensor with mixed cpu, device tensors
x = x.to(device)
ia = ia.to(cpu)
ib = ib.to(device)
test(x, ia, ib)
@deviceCountAtLeast(1)
def test_advancedindex_mixed_devices_error(self, devices) -> None:
def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
# test getitem
with self.assertRaisesRegex(RuntimeError, fr"indices should be either .* \({x.device}\)"):
value = x[:, ia, None, ib, 0]
with self.assertRaisesRegex(RuntimeError, fr"indices should be either .* \({x.device}\)"):
value = x[ib]
cpu = torch.device('cpu')
for device in devices:
# Index cpu tensor with device tensor
x = torch.randn(3, 4, 4, 4, 3)
ia = torch.tensor([0, 2, 1]).to(device)
ib = torch.tensor([0, 2, 1]).to(device)
test(x, ia, ib)
# Index cpu tensor with mixed cpu, device tensors
x = x.to(cpu)
ia = ia.to(cpu)
ib = ib.to(device)
test(x, ia, ib)
if len(devices) > 1:
other_device = devices[0] if device == devices[1] else devices[1]
# Index device tensor with mixed cpu, device tensors
x = x.to(device)
ia = ia.to(cpu)
ib = ib.to(device)
test(x, ia, ib)
if len(devices) > 1:
other_device = devices[0]
if device == devices[0]:
other_device = devices[1]
# Index device tensor with mixed cpu, device tensors on different devices
x = x.to(device)
ia = ia.to(cpu)
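For context, the removed test_advancedindex_mixed_devices_error covered the device rule introduced by the reverted PR (enforced by check_indices_on_cpu_or_selfdevice above). A hedged sketch of that rule; with the revert applied the last indexing expression is permitted again.

```python
import torch

if torch.cuda.is_available():
    x_cuda = torch.randn(3, 4, device="cuda")
    idx_cpu = torch.tensor([0, 2])
    idx_cuda = torch.tensor([0, 2], device="cuda")

    x_cuda[idx_cpu]    # dev_tensor[cpu_tensor]: always allowed
    x_cuda[idx_cuda]   # dev_tensor[dev_tensor]: always allowed

    x_cpu = torch.randn(3, 4)
    try:
        # Raised "indices should be either on cpu or on the same device as the
        # indexed tensor" under the structured port; allowed again after the revert.
        x_cpu[idx_cuda]
    except RuntimeError as err:
        print("mixed-device indexing rejected:", err)
```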

View File

@@ -105,7 +105,6 @@ _SKIP_PYTHON_BINDINGS = [
"_sparse_sub.*",
"_sparse_dense_add_out",
"index",
"index_out",
"unique_dim_consecutive",
"_cumsum.*",
"_cumprod.*",

View File

@@ -35,8 +35,7 @@ inline c10::List<c10::optional<Tensor>> unpack_opt_list(at::ArrayRef<SavedVariab
torch::List<c10::optional<Tensor>> result;
result.reserve(xs.size());
for (const SavedVariable& v : xs) {
auto var = v.unpack();
result.push_back(var.defined() ? c10::optional<Tensor>(var) : c10::nullopt);
result.push_back(v.unpack());
}
return result;
}

View File

@@ -1219,12 +1219,12 @@ REGISTER_OPERATOR_FUNCTOR(aten::index, aten_index, [](Node* n) -> SROperator {
const auto in1_l =
at::native::toListOfOptionalTensors(p_node->Input(1).toListRef());
if (p_node->Output(0).isNone()) {
p_node->Output(0) = at::cpu::index(in0_t, in1_l);
p_node->Output(0) = at::native::index(in0_t, in1_l);
return;
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::cpu::index_out(out_t, in0_t, in1_l);
at::native::index_out(out_t, in0_t, in1_l);
};
});

View File

@@ -331,7 +331,7 @@ def pad_packed_sequence(
unsorted_indices = sequence.unsorted_indices
if unsorted_indices is not None:
batch_dim = 0 if batch_first else 1
return padded_output.index_select(batch_dim, unsorted_indices), lengths[unsorted_indices.cpu()]
return padded_output.index_select(batch_dim, unsorted_indices), lengths[unsorted_indices]
return padded_output, lengths
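The hunk above drops the `.cpu()` call that the reverted PR added so the CPU `lengths` tensor could be indexed when `unsorted_indices` lives on an accelerator. A hedged sketch of the two variants; the values are illustrative, and the version-dependent case is guarded since later builds may re-enforce the device check.

```python
import torch

lengths = torch.tensor([5, 3, 4])  # lengths is always a CPU tensor here
if torch.cuda.is_available():
    unsorted_indices = torch.tensor([2, 0, 1], device="cuda")

    with_cpu_index = lengths[unsorted_indices.cpu()]  # what the PR required
    try:
        # Allowed again once the revert lands; builds that enforce the device
        # check reject indexing a CPU tensor with a CUDA index.
        with_dev_index = lengths[unsorted_indices]
        assert torch.equal(with_cpu_index, with_dev_index)
    except RuntimeError:
        pass
```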

View File

@@ -67,7 +67,6 @@ iTensorListRefT = BaseCppType("at", "ITensorListRef")
iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")
dimnameT = BaseCppType("at", "Dimname")
dimnameListT = BaseCppType("at", "DimnameList")
dimVectorT = BaseCppType("at", "DimVector")
layoutT = BaseCppType("at", "Layout")
deviceT = BaseCppType("at", "Device")
scalarT = BaseCppType("at", "Scalar")
@@ -114,7 +113,6 @@ BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
BaseTy.ScalarType: scalarTypeT,
BaseTy.Tensor: tensorT,
BaseTy.Dimname: dimnameT,
BaseTy.DimVector: dimVectorT,
BaseTy.Layout: layoutT,
BaseTy.Device: deviceT,
BaseTy.Scalar: scalarT,

View File

@@ -1500,7 +1500,6 @@ BaseTy = Enum(
"Tensor",
"int",
"Dimname",
"DimVector",
"float",
"str",
"bool",