Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-18 01:15:12 +08:00) — C++, 96 lines, 3.4 KiB.
#include <ATen/TensorIndexing.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/util/irange.h>
|
|
|
|
namespace at {
|
|
namespace indexing {
|
|
|
|
// Out-of-line definition of the global `Ellipsis` sentinel (a value-initialized
// EllipsisIndexType). Presumably this is the C++ spelling of Python's `...`
// when building TensorIndex lists — confirm against TensorIndexing.h.
const EllipsisIndexType Ellipsis{};
|
|
|
|
// Streams a Slice in Python-like `start:stop:step` form; each component is
// written with its own operator<<. Returns the same stream for chaining.
std::ostream& operator<<(std::ostream& stream, const Slice& slice) {
  return stream << slice.start() << ':' << slice.stop() << ':' << slice.step();
}
|
|
|
|
std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index) {
|
|
if (tensor_index.is_none()) {
|
|
stream << "None";
|
|
} else if (tensor_index.is_ellipsis()) {
|
|
stream << "...";
|
|
} else if (tensor_index.is_integer()) {
|
|
stream << tensor_index.integer();
|
|
} else if (tensor_index.is_boolean()) {
|
|
stream << std::boolalpha << tensor_index.boolean();
|
|
} else if (tensor_index.is_slice()) {
|
|
stream << tensor_index.slice();
|
|
} else if (tensor_index.is_tensor()) {
|
|
stream << tensor_index.tensor();
|
|
}
|
|
return stream;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) {
|
|
stream << '(';
|
|
for (const auto i : c10::irange(tensor_indices.size())) {
|
|
stream << tensor_indices[i];
|
|
if (i < tensor_indices.size() - 1) stream << ", ";
|
|
}
|
|
stream << ')';
|
|
return stream;
|
|
}
|
|
|
|
// This mirrors `THPVariable_setitem` in torch/csrc/autograd/python_variable_indexing.cpp
|
|
// for "the assigned value is a Scalar" case
|
|
static inline void set_item(const Tensor& self, ArrayRef<TensorIndex> indices, const Scalar& v) {
|
|
Tensor value;
|
|
|
|
{
|
|
at::AutoDispatchBelowADInplaceOrView guard;
|
|
at::Device self_device = self.device();
|
|
|
|
// TODO: This qint special case looks very suspicious...
|
|
if (isQIntType(self.scalar_type())) {
|
|
value = at::indexing::scalarToTensor(v, device(kCPU).dtype(kFloat), at::Device(kCPU));
|
|
} else if (self_device.is_cuda()) {
|
|
value = at::indexing::scalarToTensor(v, self.options(), at::Device(kCPU));
|
|
} else {
|
|
value = at::indexing::scalarToTensor(v, self.options(), self_device);
|
|
}
|
|
}
|
|
|
|
set_item(self, indices, value);
|
|
}
|
|
|
|
} // namespace indexing
|
|
|
|
// Reads `*this` at the given indices by delegating to at::indexing::get_item.
// Rejects an empty index list via TORCH_CHECK. The lookup runs under an
// OptionalDeviceGuard built from this tensor's device (device_of(*this)).
Tensor Tensor::index(ArrayRef<at::indexing::TensorIndex> indices) const {
  TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index() is not valid syntax");
  OptionalDeviceGuard guard_for_self(device_of(*this));
  return at::indexing::get_item(*this, indices);
}

// Brace-list convenience overload: wraps the list in an ArrayRef and forwards.
Tensor Tensor::index(std::initializer_list<at::indexing::TensorIndex> indices) const {
  const ArrayRef<at::indexing::TensorIndex> index_list(indices);
  return index(index_list);
}
|
|
|
|
// In-place assignment of a Tensor value at the given indices (delegates to
// at::indexing::set_item). Rejects an empty index list via TORCH_CHECK; the
// assignment runs under an OptionalDeviceGuard for this tensor's device.
// Returns *this for chaining.
Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, Tensor const & rhs) {
  TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
  OptionalDeviceGuard guard_for_self(device_of(*this));
  at::indexing::set_item(*this, indices, rhs);
  return *this;
}

// In-place assignment of a Scalar value at the given indices; same checks and
// device guarding as the Tensor-valued overload. Returns *this.
Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, const Scalar& v) {
  TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
  OptionalDeviceGuard guard_for_self(device_of(*this));
  at::indexing::set_item(*this, indices, v);
  return *this;
}

// Brace-list convenience overload (Tensor value): wraps and forwards.
Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, Tensor const & rhs) {
  const ArrayRef<at::indexing::TensorIndex> index_list(indices);
  return index_put_(index_list, rhs);
}

// Brace-list convenience overload (Scalar value): wraps and forwards.
Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, const Scalar& v) {
  const ArrayRef<at::indexing::TensorIndex> index_list(indices);
  return index_put_(index_list, v);
}
|
|
|
|
} // namespace at
|