Delete SymIntArrayRef wrapper struct (#84837)

Since we separated at::foo and at::foo_symint there is no benefit
to trying to make initializer lists work in both cases.  So we can
get rid of the special different struct.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84837
Approved by: https://github.com/kit1980
This commit is contained in:
Edward Z. Yang
2022-09-12 09:25:09 -07:00
committed by PyTorch MergeBot
parent 8cdc0679b9
commit 9c78f599e4
12 changed files with 35 additions and 215 deletions

View File

@ -1 +1 @@
e0dcc3171c8024ab288551d105fba24fbfae7332
09be9870437684ba2da6741af3eb10126c04aede

View File

@ -565,8 +565,6 @@ public:
}
}
IValue(c10::SymIntArrayRef v);
bool isSymInt() const {
return Tag::SymInt == tag;
}

View File

@ -1999,7 +1999,6 @@ inline IValue::IValue(at::ArrayRef<T> v) : IValue(c10::List<T>()) {
list.push_back(e);
}
}
inline IValue::IValue(c10::SymIntArrayRef v) : IValue(at::ArrayRef<c10::SymInt>(v.data(), v.size())) {}
template <class T, IValue::enable_if_ivalue_constructible<T>>
inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
auto list = to<c10::List<T>>();

View File

@ -64,7 +64,7 @@ Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) {
Tensor reshape(const Tensor& input, IntArrayRef shape) {
TORCH_CHECK(input.is_metal());
return view(input, c10::SymIntArrayRef::fromIntArrayRef(shape));
return view(input, c10::fromIntArrayRef(shape));
}
Tensor flatten_using_ints(

View File

@ -44,7 +44,7 @@ Tensor empty_strided_override(
c10::optional<c10::Device> device,
c10::optional<bool> pin_memory) {
return empty_override(SymIntArrayRef::fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt);
return empty_override(fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt);
}
TORCH_LIBRARY_IMPL(aten, ORT, m) {

View File

@ -1,15 +1,3 @@
// This file defines `SymIntArrayRef` which serves as the view onto
// std::vector<SymInt>. This class is conceptually and mostly functionally
// equivalent to ArrayRef<SymInt>.
//
// However, ArrayRef<SymInt> can't be used directly as it introduces ambiguity
// in the following cases:
// - a.expand({1, 2, 3}) matches two overloads:
// 1. `at::Tensor Tensor::expand(c10::SymIntArrayRef size, bool implicit)`
// 2. `at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit)`
// Introducing `SymIntArrayRef` allows to have a finer-grained control over
// which overload will be used.
#pragma once
#include <c10/core/SymInt.h>
@ -23,196 +11,33 @@
#include <vector>
namespace c10 {
/// SymIntArrayRef - Represent a constant reference to an array (0 or more
/// elements consecutively in memory), i.e. a start pointer and a length. It
/// allows various APIs to take consecutive elements easily and conveniently.
///
/// This class does not own the underlying data, it is expected to be used in
/// situations where the data resides in some other buffer, whose lifetime
/// extends past that of the SymIntArrayRef. For this reason, it is not in
/// general safe to store an SymIntArrayRef.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
class SymIntArrayRef final {
public:
using iterator = const c10::SymInt*;
using const_iterator = const c10::SymInt*;
using size_type = size_t;
using value_type = c10::SymInt;
using reverse_iterator = std::reverse_iterator<iterator>;
private:
ArrayRef<c10::SymInt> wrapped_symint_array_ref;
public:
/// @name Constructors
/// @{
/// Construct an empty SymIntArrayRef.
/* implicit */ constexpr SymIntArrayRef() {}
/* implicit */ SymIntArrayRef(const std::vector<c10::SymInt>& Vec)
: wrapped_symint_array_ref(Vec) {}
/// Construct an SymIntArrayRef from a pointer and length.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
const c10::SymInt* data,
size_t length)
: wrapped_symint_array_ref(data, length) {}
template <typename U>
/* implicit */ SymIntArrayRef(
const SmallVectorTemplateCommon<c10::SymInt, U>& Vec)
: wrapped_symint_array_ref(Vec) {}
/// Construct an SymIntArrayRef from a range.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
const c10::SymInt* begin,
const c10::SymInt* end)
: wrapped_symint_array_ref(begin, end) {}
/// Construct an SymIntArrayRef from a C array.
template <size_t N>
/* implicit */ constexpr SymIntArrayRef(const c10::SymInt (&Arr)[N])
: wrapped_symint_array_ref(Arr) {}
// Prefer using a more semantic constructor, like
// fromIntArrayRefKnownNonNegative
static SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) {
return SymIntArrayRef(
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}
static SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) {
return fromIntArrayRefUnchecked(array_ref);
}
static SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) {
for (size_t i = 0; i < array_ref.size(); ++i) {
TORCH_CHECK(
SymInt::check_range(array_ref[i]),
"IntArrayRef contains an int that cannot be represented as a SymInt: ",
array_ref[i]);
}
return SymIntArrayRef(
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}
/// @}
/// @name Simple Operations
/// @{
constexpr iterator begin() const {
return wrapped_symint_array_ref.begin();
}
constexpr iterator end() const {
return wrapped_symint_array_ref.end();
}
// These are actually the same as iterator, since SymIntArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return wrapped_symint_array_ref.cbegin();
}
constexpr const_iterator cend() const {
return wrapped_symint_array_ref.cend();
}
/// empty - Check if the array is empty.
constexpr bool empty() const {
return size() == 0;
}
constexpr const c10::SymInt* data() const {
return wrapped_symint_array_ref.data();
}
/// size - Get the array size.
constexpr size_t size() const {
return wrapped_symint_array_ref.size();
}
/// front - Get the first element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& front() const {
return wrapped_symint_array_ref.front();
}
/// back - Get the last element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& back() const {
return wrapped_symint_array_ref.back();
}
/// equals - Check for element-wise equality.
constexpr bool equals(SymIntArrayRef RHS) const {
return this->wrapped_symint_array_ref.equals(RHS.wrapped_symint_array_ref);
}
/// slice(n, m) - Take M elements of the array starting at element N
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef
slice(size_t N, size_t M) const {
return SymIntArrayRef(wrapped_symint_array_ref.data() + N, M);
}
/// slice(n) - Chop off the first N elements of the array.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef slice(size_t N) const {
return slice(N, size() - N);
}
/// @}
/// @name Operator Overloads
/// @{
constexpr const c10::SymInt& operator[](size_t Index) const {
return wrapped_symint_array_ref[Index];
}
/// Vector compatibility
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& at(size_t Index) const {
return wrapped_symint_array_ref.at(Index);
}
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
type&
operator=(U&& Temporary) = delete;
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
type&
operator=(std::initializer_list<U>) = delete;
/// @}
/// @name Expensive Operations
/// @{
std::vector<c10::SymInt> vec() const {
return wrapped_symint_array_ref.vec();
}
friend std::ostream& operator<<(
std::ostream& out,
const SymIntArrayRef& list);
/// @}
};
using SymIntArrayRef = ArrayRef<SymInt>;
TORCH_API at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar);
TORCH_API at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar);
TORCH_API c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt(
c10::SymIntArrayRef ar);
inline std::ostream& operator<<(
std::ostream& out,
const c10::SymIntArrayRef& list) {
return out << list.wrapped_symint_array_ref;
// Prefer using a more semantic constructor, like
// fromIntArrayRefKnownNonNegative
inline SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) {
return SymIntArrayRef(
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}
inline SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) {
return fromIntArrayRefUnchecked(array_ref);
}
inline SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) {
for (size_t i = 0; i < array_ref.size(); ++i) {
TORCH_CHECK(
SymInt::check_range(array_ref[i]),
"IntArrayRef contains an int that cannot be represented as a SymInt: ",
array_ref[i]);
}
return SymIntArrayRef(
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}
} // namespace c10

View File

@ -603,7 +603,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return sym_sizes_custom();
}
// Sizes guaranteed to be non-negative, so unchecked cast is OK
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
return c10::fromIntArrayRefKnownNonNegative(
sizes_and_strides_.sizes_arrayref());
}
@ -620,8 +620,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return extra_meta_->sizes_;
} else {
// Sizes guaranteed to be non-negative, so unchecked cast is OK
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
sizes_default());
return c10::fromIntArrayRefKnownNonNegative(sizes_default());
}
}
@ -733,8 +732,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
return sym_strides_custom();
}
// strides guaranteed to be non-negative, so unchecked cast is OK
return c10::SymIntArrayRef::fromIntArrayRefUnchecked(strides_default());
return c10::fromIntArrayRefKnownNonNegative(strides_default());
}
IntArrayRef strides_default() const {
@ -748,8 +746,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
if (has_symbolic_sizes_strides_) {
return extra_meta_->strides_;
} else {
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
strides_default());
return c10::fromIntArrayRefKnownNonNegative(strides_default());
}
}

View File

@ -103,7 +103,8 @@ TEST_F(Quantization, QuantDequantUInt8_NLC) {
parseIR(graph_string, &*graph);
auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
x.unsafeGetTensorImpl()->set_sizes_and_strides({1, 2, 2}, {4, 1, 2});
x.unsafeGetTensorImpl()->set_sizes_and_strides(
std::initializer_list<int64_t>{1, 2, 2}, {4, 1, 2});
auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
auto y_expected = at::dequantize(q);
TensorExprKernel k(graph);

View File

@ -157,7 +157,7 @@ c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
return c10::SymIntArrayRef(sym_sizes_->data(), sym_sizes_->size());
}
return c10::SymIntArrayRef::fromIntArrayRef(sizes_custom());
return c10::fromIntArrayRef(sizes_custom());
}
void LTCTensorImpl::setup_size_properties() {

View File

@ -308,7 +308,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER("lazy::");
at::Tensor t = empty_symint(
c10::SymIntArrayRef::fromIntArrayRef(size),
c10::fromIntArrayRef(size),
dtype,
layout,
device,
@ -410,7 +410,7 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
return LazyNativeFunctions::view_copy_symint(
self, c10::SymIntArrayRef::fromIntArrayRef(size));
self, c10::fromIntArrayRef(size));
}
// This is needed by the torch.tensor constructor.

View File

@ -339,7 +339,7 @@ Check this module for more information.
elif goal.type == BaseCType(symIntArrayRefT):
try:
r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
return f"c10::SymIntArrayRef::fromIntArrayRef({r})"
return f"c10::fromIntArrayRef({r})"
except UnsatError:
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
elif goal.type == BaseCType(SymIntT):

View File

@ -89,7 +89,7 @@ at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
return self.reshape(size);
} else {
auto output = at::_ops::view::call(self, c10::SymIntArrayRef::fromIntArrayRef(size));
auto output = at::_ops::view::call(self, c10::fromIntArrayRef(size));
return output.clone();
}
}