mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
Delete SymIntArrayRef wrapper struct (#84837)
Since we separated at::foo and at::foo_symint there is no benefit to trying to make initializer lists work in both cases. So we can get rid of the special different struct. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/84837 Approved by: https://github.com/kit1980
This commit is contained in:
committed by
PyTorch MergeBot
parent
8cdc0679b9
commit
9c78f599e4
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
||||
e0dcc3171c8024ab288551d105fba24fbfae7332
|
||||
09be9870437684ba2da6741af3eb10126c04aede
|
||||
|
||||
@ -565,8 +565,6 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
IValue(c10::SymIntArrayRef v);
|
||||
|
||||
bool isSymInt() const {
|
||||
return Tag::SymInt == tag;
|
||||
}
|
||||
|
||||
@ -1999,7 +1999,6 @@ inline IValue::IValue(at::ArrayRef<T> v) : IValue(c10::List<T>()) {
|
||||
list.push_back(e);
|
||||
}
|
||||
}
|
||||
inline IValue::IValue(c10::SymIntArrayRef v) : IValue(at::ArrayRef<c10::SymInt>(v.data(), v.size())) {}
|
||||
template <class T, IValue::enable_if_ivalue_constructible<T>>
|
||||
inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
|
||||
auto list = to<c10::List<T>>();
|
||||
|
||||
@ -64,7 +64,7 @@ Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) {
|
||||
|
||||
Tensor reshape(const Tensor& input, IntArrayRef shape) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
return view(input, c10::SymIntArrayRef::fromIntArrayRef(shape));
|
||||
return view(input, c10::fromIntArrayRef(shape));
|
||||
}
|
||||
|
||||
Tensor flatten_using_ints(
|
||||
|
||||
@ -44,7 +44,7 @@ Tensor empty_strided_override(
|
||||
c10::optional<c10::Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
|
||||
return empty_override(SymIntArrayRef::fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt);
|
||||
return empty_override(fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, ORT, m) {
|
||||
|
||||
@ -1,15 +1,3 @@
|
||||
// This file defines `SymIntArrayRef` which serves as the view onto
|
||||
// std::vector<SymInt>. This class is conceptually and mostly functionally
|
||||
// equivalent to ArrayRef<SymInt>.
|
||||
//
|
||||
// However, ArrayRef<SymInt> can't be used directly as it introduces ambiguity
|
||||
// in the following cases:
|
||||
// - a.expand({1, 2, 3}) matches two overloads:
|
||||
// 1. `at::Tensor Tensor::expand(c10::SymIntArrayRef size, bool implicit)`
|
||||
// 2. `at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit)`
|
||||
// Introducing `SymIntArrayRef` allows to have a finer-grained control over
|
||||
// which overload will be used.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/SymInt.h>
|
||||
@ -23,196 +11,33 @@
|
||||
#include <vector>
|
||||
|
||||
namespace c10 {
|
||||
/// SymIntArrayRef - Represent a constant reference to an array (0 or more
|
||||
/// elements consecutively in memory), i.e. a start pointer and a length. It
|
||||
/// allows various APIs to take consecutive elements easily and conveniently.
|
||||
///
|
||||
/// This class does not own the underlying data, it is expected to be used in
|
||||
/// situations where the data resides in some other buffer, whose lifetime
|
||||
/// extends past that of the SymIntArrayRef. For this reason, it is not in
|
||||
/// general safe to store an SymIntArrayRef.
|
||||
///
|
||||
/// This is intended to be trivially copyable, so it should be passed by
|
||||
/// value.
|
||||
|
||||
class SymIntArrayRef final {
|
||||
public:
|
||||
using iterator = const c10::SymInt*;
|
||||
using const_iterator = const c10::SymInt*;
|
||||
using size_type = size_t;
|
||||
using value_type = c10::SymInt;
|
||||
|
||||
using reverse_iterator = std::reverse_iterator<iterator>;
|
||||
|
||||
private:
|
||||
ArrayRef<c10::SymInt> wrapped_symint_array_ref;
|
||||
|
||||
public:
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Construct an empty SymIntArrayRef.
|
||||
/* implicit */ constexpr SymIntArrayRef() {}
|
||||
|
||||
/* implicit */ SymIntArrayRef(const std::vector<c10::SymInt>& Vec)
|
||||
: wrapped_symint_array_ref(Vec) {}
|
||||
|
||||
/// Construct an SymIntArrayRef from a pointer and length.
|
||||
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
|
||||
const c10::SymInt* data,
|
||||
size_t length)
|
||||
: wrapped_symint_array_ref(data, length) {}
|
||||
|
||||
template <typename U>
|
||||
/* implicit */ SymIntArrayRef(
|
||||
const SmallVectorTemplateCommon<c10::SymInt, U>& Vec)
|
||||
: wrapped_symint_array_ref(Vec) {}
|
||||
|
||||
/// Construct an SymIntArrayRef from a range.
|
||||
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
|
||||
const c10::SymInt* begin,
|
||||
const c10::SymInt* end)
|
||||
: wrapped_symint_array_ref(begin, end) {}
|
||||
|
||||
/// Construct an SymIntArrayRef from a C array.
|
||||
template <size_t N>
|
||||
/* implicit */ constexpr SymIntArrayRef(const c10::SymInt (&Arr)[N])
|
||||
: wrapped_symint_array_ref(Arr) {}
|
||||
|
||||
// Prefer using a more semantic constructor, like
|
||||
// fromIntArrayRefKnownNonNegative
|
||||
static SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) {
|
||||
return SymIntArrayRef(
|
||||
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
|
||||
}
|
||||
|
||||
static SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) {
|
||||
return fromIntArrayRefUnchecked(array_ref);
|
||||
}
|
||||
|
||||
static SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) {
|
||||
for (size_t i = 0; i < array_ref.size(); ++i) {
|
||||
TORCH_CHECK(
|
||||
SymInt::check_range(array_ref[i]),
|
||||
"IntArrayRef contains an int that cannot be represented as a SymInt: ",
|
||||
array_ref[i]);
|
||||
}
|
||||
return SymIntArrayRef(
|
||||
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Simple Operations
|
||||
/// @{
|
||||
|
||||
constexpr iterator begin() const {
|
||||
return wrapped_symint_array_ref.begin();
|
||||
}
|
||||
constexpr iterator end() const {
|
||||
return wrapped_symint_array_ref.end();
|
||||
}
|
||||
|
||||
// These are actually the same as iterator, since SymIntArrayRef only
|
||||
// gives you const iterators.
|
||||
constexpr const_iterator cbegin() const {
|
||||
return wrapped_symint_array_ref.cbegin();
|
||||
}
|
||||
constexpr const_iterator cend() const {
|
||||
return wrapped_symint_array_ref.cend();
|
||||
}
|
||||
|
||||
/// empty - Check if the array is empty.
|
||||
constexpr bool empty() const {
|
||||
return size() == 0;
|
||||
}
|
||||
|
||||
constexpr const c10::SymInt* data() const {
|
||||
return wrapped_symint_array_ref.data();
|
||||
}
|
||||
|
||||
/// size - Get the array size.
|
||||
constexpr size_t size() const {
|
||||
return wrapped_symint_array_ref.size();
|
||||
}
|
||||
|
||||
/// front - Get the first element.
|
||||
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& front() const {
|
||||
return wrapped_symint_array_ref.front();
|
||||
}
|
||||
|
||||
/// back - Get the last element.
|
||||
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& back() const {
|
||||
return wrapped_symint_array_ref.back();
|
||||
}
|
||||
|
||||
/// equals - Check for element-wise equality.
|
||||
constexpr bool equals(SymIntArrayRef RHS) const {
|
||||
return this->wrapped_symint_array_ref.equals(RHS.wrapped_symint_array_ref);
|
||||
}
|
||||
|
||||
/// slice(n, m) - Take M elements of the array starting at element N
|
||||
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef
|
||||
slice(size_t N, size_t M) const {
|
||||
return SymIntArrayRef(wrapped_symint_array_ref.data() + N, M);
|
||||
}
|
||||
|
||||
/// slice(n) - Chop off the first N elements of the array.
|
||||
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef slice(size_t N) const {
|
||||
return slice(N, size() - N);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Operator Overloads
|
||||
/// @{
|
||||
constexpr const c10::SymInt& operator[](size_t Index) const {
|
||||
return wrapped_symint_array_ref[Index];
|
||||
}
|
||||
|
||||
/// Vector compatibility
|
||||
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& at(size_t Index) const {
|
||||
return wrapped_symint_array_ref.at(Index);
|
||||
}
|
||||
|
||||
/// Disallow accidental assignment from a temporary.
|
||||
///
|
||||
/// The declaration here is extra complicated so that "arrayRef = {}"
|
||||
/// continues to select the move assignment operator.
|
||||
template <typename U>
|
||||
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
|
||||
type&
|
||||
operator=(U&& Temporary) = delete;
|
||||
|
||||
/// Disallow accidental assignment from a temporary.
|
||||
///
|
||||
/// The declaration here is extra complicated so that "arrayRef = {}"
|
||||
/// continues to select the move assignment operator.
|
||||
template <typename U>
|
||||
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
|
||||
type&
|
||||
operator=(std::initializer_list<U>) = delete;
|
||||
|
||||
/// @}
|
||||
/// @name Expensive Operations
|
||||
/// @{
|
||||
std::vector<c10::SymInt> vec() const {
|
||||
return wrapped_symint_array_ref.vec();
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
const SymIntArrayRef& list);
|
||||
/// @}
|
||||
};
|
||||
using SymIntArrayRef = ArrayRef<SymInt>;
|
||||
|
||||
TORCH_API at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar);
|
||||
TORCH_API at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar);
|
||||
TORCH_API c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt(
|
||||
c10::SymIntArrayRef ar);
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
const c10::SymIntArrayRef& list) {
|
||||
return out << list.wrapped_symint_array_ref;
|
||||
// Prefer using a more semantic constructor, like
|
||||
// fromIntArrayRefKnownNonNegative
|
||||
inline SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) {
|
||||
return SymIntArrayRef(
|
||||
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
|
||||
}
|
||||
|
||||
inline SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) {
|
||||
return fromIntArrayRefUnchecked(array_ref);
|
||||
}
|
||||
|
||||
inline SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) {
|
||||
for (size_t i = 0; i < array_ref.size(); ++i) {
|
||||
TORCH_CHECK(
|
||||
SymInt::check_range(array_ref[i]),
|
||||
"IntArrayRef contains an int that cannot be represented as a SymInt: ",
|
||||
array_ref[i]);
|
||||
}
|
||||
return SymIntArrayRef(
|
||||
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
@ -603,7 +603,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
return sym_sizes_custom();
|
||||
}
|
||||
// Sizes guaranteed to be non-negative, so unchecked cast is OK
|
||||
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
|
||||
return c10::fromIntArrayRefKnownNonNegative(
|
||||
sizes_and_strides_.sizes_arrayref());
|
||||
}
|
||||
|
||||
@ -620,8 +620,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
return extra_meta_->sizes_;
|
||||
} else {
|
||||
// Sizes guaranteed to be non-negative, so unchecked cast is OK
|
||||
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
|
||||
sizes_default());
|
||||
return c10::fromIntArrayRefKnownNonNegative(sizes_default());
|
||||
}
|
||||
}
|
||||
|
||||
@ -733,8 +732,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
|
||||
return sym_strides_custom();
|
||||
}
|
||||
// strides guaranteed to be non-negative, so unchecked cast is OK
|
||||
return c10::SymIntArrayRef::fromIntArrayRefUnchecked(strides_default());
|
||||
return c10::fromIntArrayRefKnownNonNegative(strides_default());
|
||||
}
|
||||
|
||||
IntArrayRef strides_default() const {
|
||||
@ -748,8 +746,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
return extra_meta_->strides_;
|
||||
} else {
|
||||
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
|
||||
strides_default());
|
||||
return c10::fromIntArrayRefKnownNonNegative(strides_default());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -103,7 +103,8 @@ TEST_F(Quantization, QuantDequantUInt8_NLC) {
|
||||
parseIR(graph_string, &*graph);
|
||||
|
||||
auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
||||
x.unsafeGetTensorImpl()->set_sizes_and_strides({1, 2, 2}, {4, 1, 2});
|
||||
x.unsafeGetTensorImpl()->set_sizes_and_strides(
|
||||
std::initializer_list<int64_t>{1, 2, 2}, {4, 1, 2});
|
||||
auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
|
||||
auto y_expected = at::dequantize(q);
|
||||
TensorExprKernel k(graph);
|
||||
|
||||
@ -157,7 +157,7 @@ c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
|
||||
return c10::SymIntArrayRef(sym_sizes_->data(), sym_sizes_->size());
|
||||
}
|
||||
|
||||
return c10::SymIntArrayRef::fromIntArrayRef(sizes_custom());
|
||||
return c10::fromIntArrayRef(sizes_custom());
|
||||
}
|
||||
|
||||
void LTCTensorImpl::setup_size_properties() {
|
||||
|
||||
@ -308,7 +308,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
|
||||
c10::optional<bool> pin_memory) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
at::Tensor t = empty_symint(
|
||||
c10::SymIntArrayRef::fromIntArrayRef(size),
|
||||
c10::fromIntArrayRef(size),
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
@ -410,7 +410,7 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
|
||||
at::IntArrayRef size) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
return LazyNativeFunctions::view_copy_symint(
|
||||
self, c10::SymIntArrayRef::fromIntArrayRef(size));
|
||||
self, c10::fromIntArrayRef(size));
|
||||
}
|
||||
|
||||
// This is needed by the torch.tensor constructor.
|
||||
|
||||
@ -339,7 +339,7 @@ Check this module for more information.
|
||||
elif goal.type == BaseCType(symIntArrayRefT):
|
||||
try:
|
||||
r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
|
||||
return f"c10::SymIntArrayRef::fromIntArrayRef({r})"
|
||||
return f"c10::fromIntArrayRef({r})"
|
||||
except UnsatError:
|
||||
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
|
||||
elif goal.type == BaseCType(SymIntT):
|
||||
|
||||
@ -89,7 +89,7 @@ at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
|
||||
if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
|
||||
return self.reshape(size);
|
||||
} else {
|
||||
auto output = at::_ops::view::call(self, c10::SymIntArrayRef::fromIntArrayRef(size));
|
||||
auto output = at::_ops::view::call(self, c10::fromIntArrayRef(size));
|
||||
return output.clone();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user