#pragma once

#include <c10/core/SymBool.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <optional>
#include <ostream>
#include <type_traits>

namespace c10 {

class SymFloat;

// SymInt represents either a regular int64_t, or a symbolic integer
// (represented in a type-erased way as SymNode). The intention is for SymInt
// to represent symbolic sizes that arise when doing shape computation in
// operator kernels. This allows for tracing through programs without baking
// concrete sizes into kernel calls.
//
// SymInt has an API equivalent to int64_t. In particular, it is a value type.
// Internally, SymInt is represented in a clever packed way, so that it only
// occupies one word of space; but morally, it is a union between an int64_t
// and an intrusive pointer to SymNodeImpl.
//
// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where
// is_int() returns true

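// Illustrative sketch only (not part of this header): SymInt is meant to be
// a drop-in replacement for int64_t in size computations.
//
//   c10::SymInt a = 3;               // small ints are stored inline, no allocation
//   c10::SymInt b = a + 4;           // stays on the non-symbolic fast path
//   if (auto v = b.maybe_as_int()) {
//     // *v == 7; nothing symbolic was involved, so no guard is recorded
//   }
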
class C10_API SymInt {
 public:
  enum Unchecked {
    UNCHECKED,
  };

  /*implicit*/ SymInt(int64_t d) : data_(d) {
    if (is_heap_allocated()) {
      // Large negative number, heap allocate it
      promote_to_negative();
    }
  }
  SymInt() : data_(0) {}
  SymInt(SymNode n);

  // unchecked c-tor accepting raw `data_`
  // One appropriate use for this is when you are constructing a symint
  // in a situation where you know it is non-negative (or, if it is negative,
  // the negative value is -1; i.e., not user controlled)
  SymInt(Unchecked /*unused*/, int64_t d) : data_(d) {}

  // TODO: these implementations are not optimal because they allocate a
  // temporary and then use the move constructor/assignment
  SymInt(const SymInt& s) : data_(0) {
    if (s.is_heap_allocated()) {
      *this = SymInt(s.toSymNode());
    } else {
      data_ = s.data_;
    }
  }
  SymInt(SymInt&& s) noexcept : data_(s.data_) {
    s.data_ = 0;
  }

  SymInt& operator=(const SymInt& s) {
    if (this != &s) {
      if (s.is_heap_allocated()) {
        *this = SymInt(s.toSymNode());
      } else {
        data_ = s.data_;
      }
    }
    return *this;
  }
  SymInt& operator=(SymInt&& s) noexcept {
    if (this != &s) {
      release_(); // release the current SymNode if any
      data_ = s.data_;
      if (s.is_heap_allocated())
        s.data_ = 0;
    };
    return *this;
  }

  SymNodeImpl* toSymNodeImplUnowned() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_heap_allocated());
    uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
    uint64_t sign_bit_mask = 1ULL << (62 - 1);
    // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
    uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
    return static_cast<SymNodeImpl*>(
        // NOLINTNEXTLINE(performance-no-int-to-ptr, bugprone*)
        reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits)));
  }

  void release_() {
    if (is_heap_allocated()) {
      SymNode::reclaim(toSymNodeImplUnowned()); // steal
    }
  }

  SymNodeImpl* release() && {
#ifndef C10_MOBILE
    TORCH_INTERNAL_ASSERT(is_heap_allocated());
    auto* r = toSymNodeImplUnowned();
    data_ = 0; // transfer ownership
    return r;
#else
    TORCH_INTERNAL_ASSERT(false);
#endif
  }

  // Only valid if is_heap_allocated()
  SymNode toSymNode() const;

  // Guaranteed to return a SymNode, wrapping using base if necessary
  SymNode wrap_node(const SymNode& base) const;

  ~SymInt() {
    release_();
  }

  // Require the int to be non-symbolic, and raise an error if it is
  // symbolic. This is safe to use in C++ code that doesn't work for
  // symbolic shapes and that you don't have time to fix immediately:
  // if the path is ever hit with a symbolic value, you'll appropriately
  // get an error.
  int64_t expect_int() const {
    if (auto r = maybe_as_int()) {
      return *r;
    }
    TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
        false, "when unpacking SymInt, expected int but got ", *this);
  }
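
  // Illustrative sketch only (not part of this header): how expect_int
  // behaves on the two kinds of SymInt. The symbolic case is assumed to
  // come from tracing with dynamic shapes.
  //
  //   c10::SymInt hinted = 5;
  //   int64_t v = hinted.expect_int(); // v == 5, no error
  //
  //   c10::SymInt symbolic = /* a size produced while tracing */;
  //   symbolic.expect_int(); // raises "when unpacking SymInt, expected int ..."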

  // Test if we have a hint for this int (e.g., guard_int would work).
  // Most of the time this is true; it is only false when you have
  // an unbacked SymInt.
  bool has_hint() const;

  // Insert a guard for the int to be its concrete value, and then return
  // that value. This operation always works, even if the int is symbolic,
  // so long as we know what the underlying value is (e.g., this won't work
  // if you call it on the size of nonzero output). Don't blindly put this
  // everywhere; you can cause overspecialization of PyTorch programs with
  // this method.
  //
  // It should be called as guard_int(__FILE__, __LINE__). The file and line
  // number can be used to diagnose overspecialization.
  int64_t guard_int(const char* file, int64_t line) const;
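
  // Illustrative sketch only (not part of this header): code that genuinely
  // needs a concrete value can guard on it, at the cost of specializing the
  // traced program to that value. `my_kernel` is a hypothetical caller.
  //
  //   void my_kernel(const c10::SymInt& size) {
  //     int64_t concrete = size.guard_int(__FILE__, __LINE__);
  //     // ... use `concrete` as a plain int64_t ...
  //   }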

  // Distinguish actual symbolic values from constants stored on the heap
  bool is_symbolic() const {
    return is_heap_allocated() &&
        !toSymNodeImplUnowned()->constant_int().has_value();
  }

  // N.B. It's important to keep this definition in the header,
  // as we expect the `if` checks to be folded for mobile builds,
  // where `is_heap_allocated` is always false, and the dead code
  // paths optimized away.
  C10_ALWAYS_INLINE bool is_heap_allocated() const {
#ifdef C10_MOBILE
    return false;
#else
    return !check_range(data_);
#endif
  }

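  // Note: the arithmetic and comparison operators below all follow the same
  // pattern. If both operands hold plain ints (maybe_as_int succeeds for
  // both), compute inline on the fast path; otherwise defer to the
  // out-of-line *_slow_path helper declared in the private section below,
  // which handles the symbolic case.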
  SymInt operator+(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymInt(*ma + *mb);
      }
    }
    return operator_add_slow_path(sci);
  }

  SymInt operator-(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymInt(*ma - *mb);
      }
    }
    return operator_sub_slow_path(sci);
  }

  SymInt operator*(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymInt(*ma * *mb);
      }
    }
    return operator_mul_slow_path(sci);
  }

  SymInt operator/(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymInt(*ma / *mb);
      }
    }
    return operator_div_slow_path(sci);
  }

  SymInt operator%(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymInt(*ma % *mb);
      }
    }
    return operator_mod_slow_path(sci);
  }

  void operator*=(const SymInt& sci) {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        *this = SymInt(*ma * *mb);
        return;
      }
    }
    operator_imul_slow_path(sci);
  }

  void operator+=(const SymInt& sci) {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        *this = SymInt(*ma + *mb);
        return;
      }
    }
    operator_iadd_slow_path(sci);
  }

  void operator/=(const SymInt& sci) {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        *this = SymInt(*ma / *mb);
        return;
      }
    }
    operator_idiv_slow_path(sci);
  }

  SymInt clone() const;

  SymBool sym_eq(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymBool(*ma == *mb);
      }
    }
    return sym_eq_slow_path(sci);
  }

  SymBool sym_ne(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymBool(*ma != *mb);
      }
    }
    return sym_ne_slow_path(sci);
  }

  SymBool sym_lt(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymBool(*ma < *mb);
      }
    }
    return sym_lt_slow_path(sci);
  }

  SymBool sym_le(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymBool(*ma <= *mb);
      }
    }
    return sym_le_slow_path(sci);
  }

  SymBool sym_gt(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymBool(*ma > *mb);
      }
    }
    return sym_gt_slow_path(sci);
  }

  SymBool sym_ge(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymBool(*ma >= *mb);
      }
    }
    return sym_ge_slow_path(sci);
  }

  bool operator==(const SymInt& o) const {
    return sym_eq(o).guard_bool(__FILE__, __LINE__);
  }
  bool operator!=(const SymInt& o) const {
    return sym_ne(o).guard_bool(__FILE__, __LINE__);
  }
  bool operator<(const SymInt& o) const {
    return sym_lt(o).guard_bool(__FILE__, __LINE__);
  }
  bool operator<=(const SymInt& o) const {
    return sym_le(o).guard_bool(__FILE__, __LINE__);
  }
  bool operator>(const SymInt& o) const {
    return sym_gt(o).guard_bool(__FILE__, __LINE__);
  }
  bool operator>=(const SymInt& o) const {
    return sym_ge(o).guard_bool(__FILE__, __LINE__);
  }

  SymInt min(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymInt(std::min(*ma, *mb));
      }
    }
    return min_slow_path(sci);
  }

  SymInt max(const SymInt& sci) const {
    if (auto ma = maybe_as_int()) {
      if (auto mb = sci.maybe_as_int()) {
        return SymInt(std::max(*ma, *mb));
      }
    }
    return max_slow_path(sci);
  }

  // If both are symbolic, this checks if
  // they share the same node.
  // If both are not symbolic this just checks normal equality.
  bool is_same(const SymInt& other) const;

  operator SymFloat() const;

  void unsafe_set_data(size_t nbytes) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated());
    data_ = static_cast<int64_t>(nbytes);
  }

  // Don't use this. Prefer maybe_as_int instead
  int64_t as_int_unchecked() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated());
    return data_;
  }

  std::optional<int64_t> maybe_as_int() const {
    if (!is_heap_allocated()) {
      return data_;
    }
    return maybe_as_int_slow_path();
  }

  // Return whether the integer is directly coercible to a SymInt
  // without requiring heap allocation. You don't need to use this
  // to check if you can pass an integer to SymInt; this is guaranteed
  // to work (it just might heap allocate!)
  static bool check_range(int64_t i) {
    return i > MAX_UNREPRESENTABLE_INT;
  }

  // Return the min representable integer as a SymInt without
  // heap allocation. For quantities that count bytes (or larger),
  // this is still much larger than you need, so you may consider
  // using this as a more efficient version of MIN_INT
  static constexpr int64_t min_representable_int() {
    return MAX_UNREPRESENTABLE_INT + 1;
  }
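
  // Worked example (illustrative only): ints in [-2^62, 2^63 - 1] are
  // stored inline; anything below -2^62 would collide with the pointer
  // tag and gets heap allocated instead.
  //
  //   SymInt::check_range(0)                                    // true -> inline
  //   SymInt::check_range(SymInt::min_representable_int())      // true -> inline
  //   SymInt::check_range(SymInt::min_representable_int() - 1)  // false -> heap allocated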

 private:
  void promote_to_negative();
  SymInt operator_add_slow_path(const SymInt& sci) const;
  SymInt operator_sub_slow_path(const SymInt& sci) const;
  SymInt operator_mul_slow_path(const SymInt& sci) const;
  SymInt operator_div_slow_path(const SymInt& sci) const;
  SymInt operator_mod_slow_path(const SymInt& sci) const;
  void operator_imul_slow_path(const SymInt& sci);
  void operator_iadd_slow_path(const SymInt& sci);
  void operator_idiv_slow_path(const SymInt& sci);
  SymBool sym_eq_slow_path(const SymInt& sci) const;
  SymBool sym_ne_slow_path(const SymInt& sci) const;
  SymBool sym_lt_slow_path(const SymInt& sci) const;
  SymBool sym_le_slow_path(const SymInt& sci) const;
  SymBool sym_gt_slow_path(const SymInt& sci) const;
  SymBool sym_ge_slow_path(const SymInt& sci) const;

  SymInt min_slow_path(const SymInt& sci) const;
  SymInt max_slow_path(const SymInt& sci) const;

  std::optional<int64_t> maybe_as_int_slow_path() const;

  // Constraints on the internal representation:
  //
  // - Should represent positive and small negative ints
  // - No conversion necessary for operations on ints
  // - Must represent valid 64-bit pointers
  // - The is-symbolic test should be FAST (two arithmetic instructions is
  //   too much). This code being a hotpath is based on Strobelight profiles
  //   of is_heap_allocated(). FB only: https://fburl.com/strobelight/5l50ncxd
  //   (you will need to change the time window).
  //
  // So, the scheme is to reserve large negative numbers (assuming
  // two's complement):
  //
  // - 0b0.... means we are a positive int
  // - 0b11... means we are a small negative int
  // - 0b10... means we are a pointer. This means that
  //   [-2^63, -2^62-1] are not representable as ints.
  //   We don't actually need all of this space, as on x86_64
  //   the top 16 bits aren't used for anything.
  static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62 | 1ULL << 61;
  static constexpr uint64_t IS_SYM = 1ULL << 63 | 1ULL << 61;
  // We must manually translate the bit pattern test into a greater
  // than test because the compiler doesn't figure it out:
  // https://godbolt.org/z/356aferaW
  static constexpr int64_t MAX_UNREPRESENTABLE_INT =
      -1LL & static_cast<int64_t>(~(1ULL << 62));
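  // Worked example (illustrative only) of how the constants above encode the
  // scheme: MAX_UNREPRESENTABLE_INT has every bit set except bit 62, i.e.
  // -2^62 - 1, so check_range(i) accepts exactly the inline range
  // [-2^62, 2^63 - 1]. A heap-allocated SymInt stores its SymNodeImpl pointer
  // in data_ with the top bits in the 0b10... region, so a single signed
  // comparison against MAX_UNREPRESENTABLE_INT distinguishes the two cases.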
  int64_t data_;
};

/// Product of a list of SymInt; accumulates into a c10::SymInt expression
template <
    typename C,
    typename std::enable_if_t<
        std::is_same_v<typename C::value_type, c10::SymInt>,
        int> = 0>
inline c10::SymInt multiply_integers(const C& container) {
  return std::accumulate(
      container.begin(),
      container.end(),
      c10::SymInt(1),
      [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
}

template <
    typename Iter,
    typename = std::enable_if_t<std::is_same_v<
        typename std::iterator_traits<Iter>::value_type,
        c10::SymInt>>>
inline c10::SymInt multiply_integers(Iter begin, Iter end) {
  return std::accumulate(
      begin,
      end,
      c10::SymInt(1),
      [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
}
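
// Illustrative sketch only (not part of this header): multiply_integers is
// typically used to compute an element count from a list of sizes. The
// `sizes` vector below is a hypothetical input.
//
//   std::vector<c10::SymInt> sizes = {2, 3, 4};
//   c10::SymInt numel = c10::multiply_integers(sizes);                    // 24
//   c10::SymInt tail = c10::multiply_integers(sizes.begin() + 1, sizes.end()); // 12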

#define DECLARE_SYMINT_OP_INTONLY(scalar_t, RetTy)      \
  C10_API RetTy operator%(const SymInt& a, scalar_t b); \
  C10_API RetTy operator%(scalar_t a, const SymInt& b);

#define DECLARE_SYMINT_OP(scalar_t, RetTy)              \
  C10_API RetTy operator+(const SymInt& a, scalar_t b); \
  C10_API RetTy operator-(const SymInt& a, scalar_t b); \
  C10_API RetTy operator*(const SymInt& a, scalar_t b); \
  C10_API RetTy operator/(const SymInt& a, scalar_t b); \
  C10_API RetTy operator+(scalar_t a, const SymInt& b); \
  C10_API RetTy operator-(scalar_t a, const SymInt& b); \
  C10_API RetTy operator*(scalar_t a, const SymInt& b); \
  C10_API RetTy operator/(scalar_t a, const SymInt& b); \
  C10_API bool operator==(const SymInt& a, scalar_t b); \
  C10_API bool operator!=(const SymInt& a, scalar_t b); \
  C10_API bool operator<(const SymInt& a, scalar_t b);  \
  C10_API bool operator<=(const SymInt& a, scalar_t b); \
  C10_API bool operator>(const SymInt& a, scalar_t b);  \
  C10_API bool operator>=(const SymInt& a, scalar_t b); \
  C10_API bool operator==(scalar_t a, const SymInt& b); \
  C10_API bool operator!=(scalar_t a, const SymInt& b); \
  C10_API bool operator<(scalar_t a, const SymInt& b);  \
  C10_API bool operator<=(scalar_t a, const SymInt& b); \
  C10_API bool operator>(scalar_t a, const SymInt& b);  \
  C10_API bool operator>=(scalar_t a, const SymInt& b);

DECLARE_SYMINT_OP_INTONLY(int64_t, SymInt)
DECLARE_SYMINT_OP_INTONLY(int32_t, SymInt)
DECLARE_SYMINT_OP_INTONLY(uint64_t, SymInt)
DECLARE_SYMINT_OP_INTONLY(uint32_t, SymInt)
DECLARE_SYMINT_OP(int64_t, SymInt)
DECLARE_SYMINT_OP(int32_t, SymInt) // make sure constants work
DECLARE_SYMINT_OP(uint64_t, SymInt)
DECLARE_SYMINT_OP(uint32_t, SymInt)
DECLARE_SYMINT_OP(double, SymFloat)
DECLARE_SYMINT_OP(float, SymFloat) // just for completeness
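
// Illustrative sketch only (not part of this header): the declarations above
// let SymInt mix with plain scalars on either side of an operator.
//
//   c10::SymInt n = 6;
//   c10::SymInt halved = n / 2;      // SymInt (integer scalar overload)
//   bool fits = n <= 8;              // bool; may record a guard if n is symbolic
//   c10::SymFloat scaled = n * 0.5;  // SymFloat (double scalar overload)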

// On OSX size_t is different from uint64_t, so we have to
// define it separately
#if defined(__APPLE__)
DECLARE_SYMINT_OP_INTONLY(size_t, SymInt)
DECLARE_SYMINT_OP(size_t, SymInt)
#endif

#undef DECLARE_SYMINT_OP

C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s);
C10_API SymInt operator-(const SymInt& s);

inline bool sym_eq(int64_t a, int64_t b) {
  return a == b;
}

inline SymBool sym_eq(const SymInt& a, const SymInt& b) {
  return a.sym_eq(b);
}

inline bool sym_ne(int64_t a, int64_t b) {
  return a != b;
}

inline SymBool sym_ne(const SymInt& a, const SymInt& b) {
  return a.sym_ne(b);
}

inline bool sym_lt(int64_t a, int64_t b) {
  return a < b;
}

inline SymBool sym_lt(const SymInt& a, const SymInt& b) {
  return a.sym_lt(b);
}

inline bool sym_le(int64_t a, int64_t b) {
  return a <= b;
}

inline SymBool sym_le(const SymInt& a, const SymInt& b) {
  return a.sym_le(b);
}

inline bool sym_gt(int64_t a, int64_t b) {
  return a > b;
}

inline SymBool sym_gt(const SymInt& a, const SymInt& b) {
  return a.sym_gt(b);
}

inline bool sym_ge(int64_t a, int64_t b) {
  return a >= b;
}

inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
  return a.sym_ge(b);
}

} // namespace c10