mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Support large negative SymInt (#99157)
The strategy is that we will heap allocate a LargeNegativeIntSymNodeImpl whenever we have a large negative int, so that we can keep the old `is_symbolic` test (now called `is_heap_allocated`) on SymInt. Whenever we need to do something with these ints, though, we convert them back into a plain `int64_t` (and then, e.g., wrap it in whatever user specificed SymNodeImpl they need.) We cannot wrap directly in the user specified SymNodeImpl as we generally do not know what the "tracing context" is from C++. We expect large negative ints to be rare, so we don't apply optimizations like singleton-ifying INT_MIN. Here's the order to review: * c10/core/SymInt.h and cpp * `is_symbolic` renamed to `is_heap_allocated` as I needed to audit all use sites: the old `is_symbolic` test would return true for large negative int, but it would be wrong to then try to dispatch on the LargeNegativeIntSymNodeImpl which supports very few operations. In this file, I had to update expect_int, * If you pass in a large negative integer, we instead heap allocate it in `promote_to_negative`. The function is written in a funny way to keep compact constructor code for SymInt (the heap allocation happens out of line) * clone is now moved out-of-line * New method maybe_as_int which will give you a constant int if it is possible, either because it's stored inline or in LargeNegativeIntSymNodeImpl. This is the preferred replacement for previous use of is_symbolic() and then as_int_unchecked(). * Rename toSymNodeImpl to toSymNode, which is more correct (since it returns a SymNode) * Complete rewrite of `normalize_symints.cpp` to use new `maybe_as_int`. Cannot easily use the old code structure, so it's now done doing a macro and typing out each case manually (it's actually not that bad.) * Reimplementations of all the unary operators by hand to use `maybe_as_int`, relatively simple. * c10/core/LargeNegativeIntSymNodeImpl.h - Just stores a int64_t value, but it has to be big and negative. Most methods are not implemented, since we will rewrap the large negative int in the real SymNodeImpl subclass before doing operations with it * The rest of the files are just rewriting code to use `maybe_as_int`. There is a nontrivial comment in c10/core/SymIntArrayRef.h Very minor test adjustment in c10/test/core/SymInt_test.cpp . Plan to exercise this properly in next PR. Companion XLA PR: https://github.com/pytorch/xla/pull/4882 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/99157 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
5c062e8bb4
commit
756a86d52c
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
|||||||
27fc8b28ee0a6dc5a223555b980be9fad4c697da
|
f235d4da06905b35d75879a0a9bc3034ab7385ac
|
||||||
|
@ -584,12 +584,12 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
IValue(c10::SymInt i) {
|
IValue(c10::SymInt i) {
|
||||||
if (i.is_symbolic()) {
|
if (auto mi = i.maybe_as_int()) {
|
||||||
tag = Tag::SymInt;
|
|
||||||
payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
|
|
||||||
} else {
|
|
||||||
tag = Tag::Int;
|
tag = Tag::Int;
|
||||||
payload.u.as_int = i.as_int_unchecked();
|
payload.u.as_int = *mi;
|
||||||
|
} else {
|
||||||
|
tag = Tag::SymInt;
|
||||||
|
payload.u.as_intrusive_ptr = i.toSymNode().release();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -855,7 +855,7 @@ public:
|
|||||||
// for both SymFloat and double
|
// for both SymFloat and double
|
||||||
if (s.isSymInt()) {
|
if (s.isSymInt()) {
|
||||||
tag = Tag::SymInt;
|
tag = Tag::SymInt;
|
||||||
payload.u.as_intrusive_ptr = s.toSymInt().toSymNodeImpl().release();
|
payload.u.as_intrusive_ptr = s.toSymInt().toSymNode().release();
|
||||||
} else if (s.isSymFloat()) {
|
} else if (s.isSymFloat()) {
|
||||||
tag = Tag::SymFloat;
|
tag = Tag::SymFloat;
|
||||||
payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
|
payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
|
||||||
|
50
c10/core/LargeNegativeIntSymNodeImpl.h
Normal file
50
c10/core/LargeNegativeIntSymNodeImpl.h
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
#include <c10/core/SymNodeImpl.h>
|
||||||
|
|
||||||
|
namespace c10 {
|
||||||
|
|
||||||
|
// Represents an otherwise unrepresentable large negative integer constant.
|
||||||
|
// Unlike other SymNodeImpl, this cannot be "dispatched" conventionally,
|
||||||
|
// as it typically needs to defer to another SymNodeImpl
|
||||||
|
class C10_API LargeNegativeIntSymNodeImpl : public SymNodeImpl {
|
||||||
|
public:
|
||||||
|
LargeNegativeIntSymNodeImpl(int64_t val) : val_(val) {}
|
||||||
|
|
||||||
|
bool is_int() override {
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
bool is_bool() override {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
bool is_float() override {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
int64_t guard_int(const char* file, int64_t line) override {
|
||||||
|
return val_;
|
||||||
|
};
|
||||||
|
bool guard_bool(const char* file, int64_t line) override {
|
||||||
|
TORCH_CHECK(false, "not a bool");
|
||||||
|
};
|
||||||
|
double guard_float(const char* file, int64_t line) override {
|
||||||
|
TORCH_CHECK(false, "not a float");
|
||||||
|
};
|
||||||
|
int64_t int_() override {
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
bool bool_() override {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
bool has_hint() override {
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
std::string str() override {
|
||||||
|
return std::to_string(val_);
|
||||||
|
};
|
||||||
|
int64_t large_negative_int() override {
|
||||||
|
return val_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64_t val_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace c10
|
@ -276,12 +276,12 @@ class C10_API Scalar {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Scalar(c10::SymInt si) {
|
Scalar(c10::SymInt si) {
|
||||||
if (si.is_symbolic()) {
|
if (auto m = si.maybe_as_int()) {
|
||||||
|
tag = Tag::HAS_i;
|
||||||
|
v.i = *m;
|
||||||
|
} else {
|
||||||
tag = Tag::HAS_si;
|
tag = Tag::HAS_si;
|
||||||
v.p = std::move(si).release();
|
v.p = std::move(si).release();
|
||||||
} else {
|
|
||||||
tag = Tag::HAS_i;
|
|
||||||
v.i = si.as_int_unchecked();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||||||
bool resizable)
|
bool resizable)
|
||||||
: data_ptr_(std::move(data_ptr)),
|
: data_ptr_(std::move(data_ptr)),
|
||||||
size_bytes_(std::move(size_bytes)),
|
size_bytes_(std::move(size_bytes)),
|
||||||
size_bytes_is_symbolic_(size_bytes_.is_symbolic()),
|
size_bytes_is_heap_allocated_(size_bytes_.is_heap_allocated()),
|
||||||
resizable_(resizable),
|
resizable_(resizable),
|
||||||
received_cuda_(false),
|
received_cuda_(false),
|
||||||
allocator_(allocator) {
|
allocator_(allocator) {
|
||||||
@ -62,7 +62,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||||||
: StorageImpl(
|
: StorageImpl(
|
||||||
use_byte_size_t(),
|
use_byte_size_t(),
|
||||||
size_bytes,
|
size_bytes,
|
||||||
size_bytes.is_symbolic()
|
size_bytes.is_heap_allocated()
|
||||||
? allocator->allocate(0)
|
? allocator->allocate(0)
|
||||||
: allocator->allocate(size_bytes.as_int_unchecked()),
|
: allocator->allocate(size_bytes.as_int_unchecked()),
|
||||||
allocator,
|
allocator,
|
||||||
@ -78,7 +78,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||||||
void reset() {
|
void reset() {
|
||||||
data_ptr_.clear();
|
data_ptr_.clear();
|
||||||
size_bytes_ = 0;
|
size_bytes_ = 0;
|
||||||
size_bytes_is_symbolic_ = false;
|
size_bytes_is_heap_allocated_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destructor doesn't call release_resources because it's
|
// Destructor doesn't call release_resources because it's
|
||||||
@ -88,7 +88,8 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t nbytes() const {
|
size_t nbytes() const {
|
||||||
TORCH_CHECK(!size_bytes_is_symbolic_);
|
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
|
||||||
|
TORCH_CHECK(!size_bytes_is_heap_allocated_);
|
||||||
return size_bytes_.as_int_unchecked();
|
return size_bytes_.as_int_unchecked();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,7 +100,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||||||
// TODO: remove later
|
// TODO: remove later
|
||||||
void set_nbytes(size_t size_bytes) {
|
void set_nbytes(size_t size_bytes) {
|
||||||
size_bytes_ = size_bytes;
|
size_bytes_ = size_bytes;
|
||||||
size_bytes_is_symbolic_ = false;
|
size_bytes_is_heap_allocated_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_nbytes(c10::SymInt size_bytes) {
|
void set_nbytes(c10::SymInt size_bytes) {
|
||||||
@ -188,7 +189,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||||||
size_t size_bytes) {
|
size_t size_bytes) {
|
||||||
data_ptr_ = std::move(data_ptr);
|
data_ptr_ = std::move(data_ptr);
|
||||||
size_bytes_ = size_bytes;
|
size_bytes_ = size_bytes;
|
||||||
size_bytes_is_symbolic_ = false;
|
size_bytes_is_heap_allocated_ = false;
|
||||||
allocator_ = nullptr;
|
allocator_ = nullptr;
|
||||||
resizable_ = false;
|
resizable_ = false;
|
||||||
}
|
}
|
||||||
@ -206,7 +207,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||||||
private:
|
private:
|
||||||
DataPtr data_ptr_;
|
DataPtr data_ptr_;
|
||||||
SymInt size_bytes_;
|
SymInt size_bytes_;
|
||||||
bool size_bytes_is_symbolic_;
|
bool size_bytes_is_heap_allocated_;
|
||||||
bool resizable_;
|
bool resizable_;
|
||||||
// Identifies that Storage was received from another process and doesn't have
|
// Identifies that Storage was received from another process and doesn't have
|
||||||
// local to process cuda memory allocation
|
// local to process cuda memory allocation
|
||||||
|
@ -1,44 +1,31 @@
|
|||||||
|
#include <c10/core/LargeNegativeIntSymNodeImpl.h>
|
||||||
#include <c10/core/SymFloat.h>
|
#include <c10/core/SymFloat.h>
|
||||||
#include <c10/core/SymInt.h>
|
#include <c10/core/SymInt.h>
|
||||||
#include <c10/core/SymNodeImpl.h>
|
#include <c10/core/SymNodeImpl.h>
|
||||||
|
#include <c10/util/intrusive_ptr.h>
|
||||||
#include <array>
|
#include <array>
|
||||||
|
#include <functional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
|
|
||||||
static std::array<SymNode, 2> normalize_symints(
|
// Precondition: data_ has a large negative number that should be
|
||||||
const SymInt& a_,
|
// treated as a constant. It is NOT a valid pointer. In other words,
|
||||||
const SymInt& b_) {
|
// SymInt has temporarily violated invariants
|
||||||
SymNode a, b;
|
// Postcondition: invariants on SymInt are fixed
|
||||||
if (a_.is_symbolic())
|
void SymInt::promote_to_negative() {
|
||||||
a = a_.toSymNodeImpl();
|
auto s =
|
||||||
if (b_.is_symbolic())
|
SymInt(SymNode(c10::make_intrusive<LargeNegativeIntSymNodeImpl>(data_)));
|
||||||
b = b_.toSymNodeImpl();
|
// Similar to move operator=, but do NOT release data_
|
||||||
|
data_ = s.data_;
|
||||||
SymNodeImpl* common = a ? a.get() : b.get();
|
s.data_ = 0;
|
||||||
// TODO: technically we need to check that the classes match
|
|
||||||
if (!a) {
|
|
||||||
a = common->wrap_int(a_.as_int_unchecked());
|
|
||||||
}
|
|
||||||
if (!b) {
|
|
||||||
b = common->wrap_int(b_.as_int_unchecked());
|
|
||||||
}
|
|
||||||
return {std::move(a), std::move(b)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SymNode SymInt::toSymNodeImpl() const {
|
SymNode SymInt::toSymNode() const {
|
||||||
TORCH_CHECK(is_symbolic());
|
TORCH_CHECK(is_heap_allocated());
|
||||||
return SymNode::reclaim_copy(toSymNodeImplUnowned());
|
return SymNode::reclaim_copy(toSymNodeImplUnowned());
|
||||||
}
|
}
|
||||||
|
|
||||||
SymNode SymInt::wrap_node(const SymNode& base) const {
|
|
||||||
if (is_symbolic()) {
|
|
||||||
return toSymNodeImpl();
|
|
||||||
} else {
|
|
||||||
return base->wrap_int(as_int_unchecked());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SymInt::SymInt(SymNode sin_sp) {
|
SymInt::SymInt(SymNode sin_sp) {
|
||||||
TORCH_CHECK(sin_sp->is_int());
|
TORCH_CHECK(sin_sp->is_int());
|
||||||
auto ptr = static_cast<uint64_t>(
|
auto ptr = static_cast<uint64_t>(
|
||||||
@ -47,129 +34,86 @@ SymInt::SymInt(SymNode sin_sp) {
|
|||||||
data_ = static_cast<int64_t>(rep);
|
data_ = static_cast<int64_t>(rep);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t SymInt::guard_int(const char* file, int64_t line) const {
|
|
||||||
if (!is_symbolic()) {
|
|
||||||
return data_;
|
|
||||||
}
|
|
||||||
SymNode a = toSymNodeImpl();
|
|
||||||
return a->guard_int(file, line);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool SymInt::has_hint() const {
|
bool SymInt::has_hint() const {
|
||||||
if (!is_symbolic()) {
|
if (!is_heap_allocated()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return toSymNodeImpl()->has_hint();
|
return toSymNodeImplUnowned()->has_hint();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define DEFINE_BINARY(API, OP, METHOD, RET) \
|
||||||
|
RET SymInt::API(const SymInt& sci) const { \
|
||||||
|
if (auto ma = maybe_as_int()) { \
|
||||||
|
if (auto mb = sci.maybe_as_int()) { \
|
||||||
|
return RET(OP(*ma, *mb)); \
|
||||||
|
} else { \
|
||||||
|
auto b = sci.toSymNode(); \
|
||||||
|
return RET(b->wrap_int(*ma)->METHOD(b)); \
|
||||||
|
} \
|
||||||
|
} else { \
|
||||||
|
if (auto mb = sci.maybe_as_int()) { \
|
||||||
|
auto a = toSymNodeImplUnowned(); \
|
||||||
|
return RET(a->METHOD(a->wrap_int(*mb))); \
|
||||||
|
} else { \
|
||||||
|
return RET(toSymNodeImplUnowned()->METHOD(sci.toSymNode())); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
DEFINE_BINARY(operator+, std::plus<>(), add, SymInt)
|
||||||
|
DEFINE_BINARY(operator-, std::minus<>(), sub, SymInt)
|
||||||
|
DEFINE_BINARY(operator*, std::multiplies<>(), mul, SymInt)
|
||||||
|
DEFINE_BINARY(operator/, std::divides<>(), floordiv, SymInt)
|
||||||
|
DEFINE_BINARY(operator%, std::modulus<>(), mod, SymInt)
|
||||||
|
DEFINE_BINARY(sym_eq, std::equal_to<>(), eq, SymBool)
|
||||||
|
DEFINE_BINARY(sym_ne, std::not_equal_to<>(), ne, SymBool)
|
||||||
|
DEFINE_BINARY(sym_lt, std::less<>(), lt, SymBool)
|
||||||
|
DEFINE_BINARY(sym_le, std::less_equal<>(), le, SymBool)
|
||||||
|
DEFINE_BINARY(sym_gt, std::greater<>(), gt, SymBool)
|
||||||
|
DEFINE_BINARY(sym_ge, std::greater_equal<>(), ge, SymBool)
|
||||||
|
DEFINE_BINARY(min, std::min, sym_min, SymInt)
|
||||||
|
DEFINE_BINARY(max, std::max, sym_max, SymInt)
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
SymInt::operator SymFloat() const {
|
SymInt::operator SymFloat() const {
|
||||||
if (!is_symbolic()) {
|
if (auto ma = maybe_as_int()) {
|
||||||
return SymFloat(double(data_));
|
return SymFloat(double(*ma));
|
||||||
|
} else {
|
||||||
|
return SymFloat(toSymNodeImplUnowned()->sym_float());
|
||||||
}
|
}
|
||||||
return SymFloat(toSymNodeImpl()->sym_float());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SymInt SymInt::operator+(const SymInt& sci) const {
|
SymNode SymInt::wrap_node(const SymNode& base) const {
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
if (auto ma = maybe_as_int()) {
|
||||||
return SymInt(data_ + sci.data_);
|
return base->wrap_int(*ma);
|
||||||
|
} else {
|
||||||
|
return toSymNode();
|
||||||
}
|
}
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return SymInt(res[0]->add(res[1]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SymInt SymInt::operator-(const SymInt& sci) const {
|
SymInt SymInt::clone() const {
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
if (auto ma = maybe_as_int()) {
|
||||||
return SymInt(data_ - sci.data_);
|
return SymInt(*ma);
|
||||||
|
} else {
|
||||||
|
return SymInt(toSymNodeImplUnowned()->clone());
|
||||||
}
|
}
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return SymInt(res[0]->sub(res[1]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SymInt SymInt::operator*(const SymInt& sci) const {
|
int64_t SymInt::guard_int(const char* file, int64_t line) const {
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
if (auto ma = maybe_as_int()) {
|
||||||
return SymInt(data_ * sci.data_);
|
return *ma;
|
||||||
|
} else {
|
||||||
|
return toSymNodeImplUnowned()->guard_int(file, line);
|
||||||
}
|
}
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return SymInt(res[0]->mul(res[1]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SymInt SymInt::operator/(const SymInt& sci) const {
|
SymInt operator-(const SymInt& s) {
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
if (auto ma = s.maybe_as_int()) {
|
||||||
return SymInt(data_ / sci.data_);
|
return SymInt(-*ma);
|
||||||
|
} else {
|
||||||
|
return SymInt(s.toSymNodeImplUnowned()->neg());
|
||||||
}
|
}
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return SymInt(res[0]->floordiv(res[1]));
|
|
||||||
}
|
|
||||||
|
|
||||||
SymInt SymInt::operator%(const SymInt& sci) const {
|
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
|
||||||
return SymInt(data_ % sci.data_);
|
|
||||||
}
|
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return SymInt(res[0]->mod(res[1]));
|
|
||||||
}
|
|
||||||
|
|
||||||
SymBool SymInt::sym_eq(const SymInt& sci) const {
|
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
|
||||||
return data_ == sci.data_;
|
|
||||||
}
|
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return res[0]->eq(res[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymBool SymInt::sym_ne(const SymInt& sci) const {
|
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
|
||||||
return data_ != sci.data_;
|
|
||||||
}
|
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return res[0]->ne(res[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymBool SymInt::sym_lt(const SymInt& sci) const {
|
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
|
||||||
return data_ < sci.data_;
|
|
||||||
}
|
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return res[0]->lt(res[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymBool SymInt::sym_le(const SymInt& sci) const {
|
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
|
||||||
return data_ <= sci.data_;
|
|
||||||
}
|
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return res[0]->le(res[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymBool SymInt::sym_gt(const SymInt& sci) const {
|
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
|
||||||
return data_ > sci.data_;
|
|
||||||
}
|
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return res[0]->gt(res[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymBool SymInt::sym_ge(const SymInt& sci) const {
|
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
|
||||||
return data_ >= sci.data_;
|
|
||||||
}
|
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return res[0]->ge(res[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymInt SymInt::min(const SymInt& sci) const {
|
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
|
||||||
return std::min(data_, sci.data_);
|
|
||||||
}
|
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return SymInt(res[0]->sym_min(res[1]));
|
|
||||||
}
|
|
||||||
SymInt SymInt::max(const SymInt& sci) const {
|
|
||||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
|
||||||
return std::max(data_, sci.data_);
|
|
||||||
}
|
|
||||||
auto res = normalize_symints(*this, sci);
|
|
||||||
return SymInt(res[0]->sym_max(res[1]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SymInt::operator*=(const SymInt& sci) {
|
void SymInt::operator*=(const SymInt& sci) {
|
||||||
@ -213,20 +157,12 @@ SymInt SymInt::operator*(int64_t sci) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const SymInt& s) {
|
std::ostream& operator<<(std::ostream& os, const SymInt& s) {
|
||||||
if (s.is_symbolic()) {
|
if (s.is_heap_allocated()) {
|
||||||
os << s.toSymNodeImpl()->str();
|
os << s.toSymNodeImplUnowned()->str();
|
||||||
} else {
|
} else {
|
||||||
os << s.as_int_unchecked();
|
os << s.as_int_unchecked();
|
||||||
}
|
}
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
SymInt operator-(const SymInt& s) {
|
|
||||||
if (s.is_symbolic()) {
|
|
||||||
return SymInt(s.toSymNodeImpl()->neg());
|
|
||||||
} else {
|
|
||||||
return SymInt(-s.as_int_unchecked());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
#include <c10/core/SymNodeImpl.h>
|
#include <c10/core/SymNodeImpl.h>
|
||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
#include <c10/util/Optional.h>
|
||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
@ -32,10 +33,10 @@ class C10_API SymInt {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/*implicit*/ SymInt(int64_t d) : data_(d) {
|
/*implicit*/ SymInt(int64_t d) : data_(d) {
|
||||||
// NB: this relies on exception in constructor inhibiting
|
if (is_heap_allocated()) {
|
||||||
// destructor; otherwise we would attempt to deallocate
|
// Large negative number, heap allocate it
|
||||||
// the garbage data!
|
promote_to_negative();
|
||||||
TORCH_CHECK(!is_symbolic());
|
}
|
||||||
};
|
};
|
||||||
SymInt() : data_(0) {}
|
SymInt() : data_(0) {}
|
||||||
SymInt(SymNode n);
|
SymInt(SymNode n);
|
||||||
@ -49,8 +50,8 @@ class C10_API SymInt {
|
|||||||
// TODO: these implementations are not optimal because they allocate a
|
// TODO: these implementations are not optimal because they allocate a
|
||||||
// temporary and then use the move constructor/assignment
|
// temporary and then use the move constructor/assignment
|
||||||
SymInt(const SymInt& s) : data_(0) {
|
SymInt(const SymInt& s) : data_(0) {
|
||||||
if (s.is_symbolic()) {
|
if (s.is_heap_allocated()) {
|
||||||
*this = SymInt(s.toSymNodeImpl());
|
*this = SymInt(s.toSymNode());
|
||||||
} else {
|
} else {
|
||||||
data_ = s.data_;
|
data_ = s.data_;
|
||||||
}
|
}
|
||||||
@ -61,8 +62,8 @@ class C10_API SymInt {
|
|||||||
|
|
||||||
SymInt& operator=(const SymInt& s) {
|
SymInt& operator=(const SymInt& s) {
|
||||||
if (this != &s) {
|
if (this != &s) {
|
||||||
if (s.is_symbolic()) {
|
if (s.is_heap_allocated()) {
|
||||||
*this = SymInt(s.toSymNodeImpl());
|
*this = SymInt(s.toSymNode());
|
||||||
} else {
|
} else {
|
||||||
data_ = s.data_;
|
data_ = s.data_;
|
||||||
}
|
}
|
||||||
@ -73,21 +74,14 @@ class C10_API SymInt {
|
|||||||
if (this != &s) {
|
if (this != &s) {
|
||||||
release_(); // release the current SymNode if any
|
release_(); // release the current SymNode if any
|
||||||
data_ = s.data_;
|
data_ = s.data_;
|
||||||
if (s.is_symbolic())
|
if (s.is_heap_allocated())
|
||||||
s.data_ = 0;
|
s.data_ = 0;
|
||||||
};
|
};
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
SymInt clone() const {
|
|
||||||
if (is_symbolic()) {
|
|
||||||
return SymInt(toSymNodeImplUnowned()->clone());
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
SymNodeImpl* toSymNodeImplUnowned() const {
|
SymNodeImpl* toSymNodeImplUnowned() const {
|
||||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic());
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_heap_allocated());
|
||||||
uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
|
uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
|
||||||
uint64_t sign_bit_mask = 1ULL << (62 - 1);
|
uint64_t sign_bit_mask = 1ULL << (62 - 1);
|
||||||
// https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
|
// https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
|
||||||
@ -97,14 +91,14 @@ class C10_API SymInt {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void release_() {
|
void release_() {
|
||||||
if (is_symbolic()) {
|
if (is_heap_allocated()) {
|
||||||
SymNode::reclaim(toSymNodeImplUnowned()); // steal
|
SymNode::reclaim(toSymNodeImplUnowned()); // steal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SymNodeImpl* release() && {
|
SymNodeImpl* release() && {
|
||||||
#ifndef C10_MOBILE
|
#ifndef C10_MOBILE
|
||||||
TORCH_INTERNAL_ASSERT(is_symbolic());
|
TORCH_INTERNAL_ASSERT(is_heap_allocated());
|
||||||
auto* r = toSymNodeImplUnowned();
|
auto* r = toSymNodeImplUnowned();
|
||||||
data_ = 0; // transfer ownership
|
data_ = 0; // transfer ownership
|
||||||
return r;
|
return r;
|
||||||
@ -113,8 +107,8 @@ class C10_API SymInt {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only valid if is_symbolic()
|
// Only valid if is_heap_allocated()
|
||||||
SymNode toSymNodeImpl() const;
|
SymNode toSymNode() const;
|
||||||
|
|
||||||
// Guaranteed to return a SymNode, wrapping using base if necessary
|
// Guaranteed to return a SymNode, wrapping using base if necessary
|
||||||
SymNode wrap_node(const SymNode& base) const;
|
SymNode wrap_node(const SymNode& base) const;
|
||||||
@ -128,8 +122,10 @@ class C10_API SymInt {
|
|||||||
// shapes, and you don't have time to fix it immediately, as if we
|
// shapes, and you don't have time to fix it immediately, as if we
|
||||||
// try to trigger the path in C++ you'll appropriately get an error
|
// try to trigger the path in C++ you'll appropriately get an error
|
||||||
int64_t expect_int() const {
|
int64_t expect_int() const {
|
||||||
TORCH_CHECK(!is_symbolic());
|
if (auto r = maybe_as_int()) {
|
||||||
return data_;
|
return *r;
|
||||||
|
}
|
||||||
|
TORCH_CHECK(false, "expected int but got ", *this);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test if we have a hint for this int (e.g., guard_int would work).
|
// Test if we have a hint for this int (e.g., guard_int would work).
|
||||||
@ -150,8 +146,8 @@ class C10_API SymInt {
|
|||||||
|
|
||||||
// N.B. It's important to keep this definition in the header
|
// N.B. It's important to keep this definition in the header
|
||||||
// as we expect if checks to be folded for mobile builds
|
// as we expect if checks to be folded for mobile builds
|
||||||
// where `is_symbolic` is always false and optimize dead code paths
|
// where `is_heap_allocated` is always false and optimize dead code paths
|
||||||
C10_ALWAYS_INLINE bool is_symbolic() const {
|
C10_ALWAYS_INLINE bool is_heap_allocated() const {
|
||||||
#ifdef C10_MOBILE
|
#ifdef C10_MOBILE
|
||||||
return false;
|
return false;
|
||||||
#else
|
#else
|
||||||
@ -168,6 +164,8 @@ class C10_API SymInt {
|
|||||||
void operator+=(const SymInt& sci);
|
void operator+=(const SymInt& sci);
|
||||||
void operator/=(const SymInt& sci);
|
void operator/=(const SymInt& sci);
|
||||||
|
|
||||||
|
SymInt clone() const;
|
||||||
|
|
||||||
SymBool sym_eq(const SymInt&) const;
|
SymBool sym_eq(const SymInt&) const;
|
||||||
SymBool sym_ne(const SymInt&) const;
|
SymBool sym_ne(const SymInt&) const;
|
||||||
SymBool sym_lt(const SymInt&) const;
|
SymBool sym_lt(const SymInt&) const;
|
||||||
@ -207,22 +205,42 @@ class C10_API SymInt {
|
|||||||
|
|
||||||
operator SymFloat() const;
|
operator SymFloat() const;
|
||||||
|
|
||||||
|
// Don't use this. Prefer maybe_as_int instead
|
||||||
int64_t as_int_unchecked() const {
|
int64_t as_int_unchecked() const {
|
||||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_symbolic());
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated());
|
||||||
return data_;
|
return data_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return whether the integer is representable as a SymInt.
|
c10::optional<int64_t> maybe_as_int() const {
|
||||||
|
if (!is_heap_allocated()) {
|
||||||
|
return c10::make_optional(data_);
|
||||||
|
}
|
||||||
|
int64_t c = toSymNodeImplUnowned()->large_negative_int();
|
||||||
|
if (c != 0) {
|
||||||
|
return c10::make_optional(c);
|
||||||
|
}
|
||||||
|
return c10::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return whether the integer is directly coercible to a SymInt
|
||||||
|
// without requiring heap allocation. You don't need to use this
|
||||||
|
// to check if you can pass an integer to SymInt; this is guaranteed
|
||||||
|
// to work (it just might heap allocate!)
|
||||||
static bool check_range(int64_t i) {
|
static bool check_range(int64_t i) {
|
||||||
return i > MAX_UNREPRESENTABLE_INT;
|
return i > MAX_UNREPRESENTABLE_INT;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the min representable integer as a SymInt
|
// Return the min representable integer as a SymInt without
|
||||||
|
// heap allocation. For quantities that count bytes (or larger),
|
||||||
|
// this is still much larger than you need, so you may consider
|
||||||
|
// using this as a more efficient version of MIN_INT
|
||||||
static constexpr int64_t min_representable_int() {
|
static constexpr int64_t min_representable_int() {
|
||||||
return MAX_UNREPRESENTABLE_INT + 1;
|
return MAX_UNREPRESENTABLE_INT + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void promote_to_negative();
|
||||||
|
|
||||||
// Constraints on the internal representation:
|
// Constraints on the internal representation:
|
||||||
//
|
//
|
||||||
// - Should represent positive and small negative ints
|
// - Should represent positive and small negative ints
|
||||||
@ -231,7 +249,7 @@ class C10_API SymInt {
|
|||||||
// - Is symbolic test should be FAST (two arithmetic instructions is too
|
// - Is symbolic test should be FAST (two arithmetic instructions is too
|
||||||
// much).
|
// much).
|
||||||
// This code being a hotpath is based on Strobelight profiles of
|
// This code being a hotpath is based on Strobelight profiles of
|
||||||
// is_symbolic(). FB only: https://fburl.com/strobelight/5l50ncxd
|
// is_heap_allocated(). FB only: https://fburl.com/strobelight/5l50ncxd
|
||||||
// (you will need to change the time window).
|
// (you will need to change the time window).
|
||||||
//
|
//
|
||||||
// So, the scheme is to reserve large negative numbers (assuming
|
// So, the scheme is to reserve large negative numbers (assuming
|
||||||
|
@ -12,10 +12,16 @@ inline at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar) {
|
|||||||
return IntArrayRef(reinterpret_cast<const int64_t*>(ar.data()), ar.size());
|
return IntArrayRef(reinterpret_cast<const int64_t*>(ar.data()), ar.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: a SymIntArrayRef containing a heap allocated large negative integer
|
||||||
|
// can actually technically be converted to an IntArrayRef... but not with
|
||||||
|
// the non-owning API we have here. We can't reinterpet cast; we have to
|
||||||
|
// allocate another buffer and write the integers into it. If you need it,
|
||||||
|
// we can do it. But I don't think you need it.
|
||||||
|
|
||||||
inline c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt(
|
inline c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt(
|
||||||
c10::SymIntArrayRef ar) {
|
c10::SymIntArrayRef ar) {
|
||||||
for (const c10::SymInt& sci : ar) {
|
for (const c10::SymInt& sci : ar) {
|
||||||
if (sci.is_symbolic()) {
|
if (sci.is_heap_allocated()) {
|
||||||
return c10::nullopt;
|
return c10::nullopt;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -29,7 +35,7 @@ inline at::IntArrayRef asIntArrayRefSlow(
|
|||||||
int64_t line) {
|
int64_t line) {
|
||||||
for (const c10::SymInt& sci : ar) {
|
for (const c10::SymInt& sci : ar) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
!sci.is_symbolic(),
|
!sci.is_heap_allocated(),
|
||||||
file,
|
file,
|
||||||
":",
|
":",
|
||||||
line,
|
line,
|
||||||
|
@ -14,6 +14,7 @@ using SymNode = c10::intrusive_ptr<SymNodeImpl>;
|
|||||||
// When you add a method, you also need to edit
|
// When you add a method, you also need to edit
|
||||||
// torch/csrc/jit/python/init.cpp
|
// torch/csrc/jit/python/init.cpp
|
||||||
// torch/csrc/utils/python_symnode.h
|
// torch/csrc/utils/python_symnode.h
|
||||||
|
// c10/core/ConstantSymNodeImpl.h
|
||||||
class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
||||||
public:
|
public:
|
||||||
~SymNodeImpl() override = default;
|
~SymNodeImpl() override = default;
|
||||||
@ -163,6 +164,9 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
|||||||
virtual std::string str() {
|
virtual std::string str() {
|
||||||
TORCH_CHECK(false, "NYI");
|
TORCH_CHECK(false, "NYI");
|
||||||
};
|
};
|
||||||
|
virtual int64_t large_negative_int() {
|
||||||
|
return 0; // not a large negative int!
|
||||||
|
}
|
||||||
std::ostream& operator<<(std::ostream& os) {
|
std::ostream& operator<<(std::ostream& os) {
|
||||||
os << str();
|
os << str();
|
||||||
return os;
|
return os;
|
||||||
|
@ -232,20 +232,22 @@ normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) {
|
|||||||
// Look for a SymNode to dispatch on
|
// Look for a SymNode to dispatch on
|
||||||
SymNode base;
|
SymNode base;
|
||||||
bool all_hinted = true;
|
bool all_hinted = true;
|
||||||
|
// NB: sizes/strides guaranteed to be positive, so only need
|
||||||
|
// is_heap_allocated
|
||||||
for (const auto& s : sizes) {
|
for (const auto& s : sizes) {
|
||||||
if (all_hinted && !s.has_hint()) {
|
if (all_hinted && !s.has_hint()) {
|
||||||
all_hinted = false;
|
all_hinted = false;
|
||||||
}
|
}
|
||||||
if (!base && s.is_symbolic()) {
|
if (!base && s.is_heap_allocated()) {
|
||||||
base = s.toSymNodeImpl();
|
base = s.toSymNode();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (const auto& s : strides) {
|
for (const auto& s : strides) {
|
||||||
if (all_hinted && !s.has_hint()) {
|
if (all_hinted && !s.has_hint()) {
|
||||||
all_hinted = false;
|
all_hinted = false;
|
||||||
}
|
}
|
||||||
if (!base && s.is_symbolic()) {
|
if (!base && s.is_heap_allocated()) {
|
||||||
base = s.toSymNodeImpl();
|
base = s.toSymNode();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!base || all_hinted) {
|
if (!base || all_hinted) {
|
||||||
@ -1125,7 +1127,8 @@ void TensorImpl::set_sizes_and_strides(
|
|||||||
auto int_sizes = asIntArrayRefSlowOpt(sizes);
|
auto int_sizes = asIntArrayRefSlowOpt(sizes);
|
||||||
auto int_strides = asIntArrayRefSlowOpt(strides);
|
auto int_strides = asIntArrayRefSlowOpt(strides);
|
||||||
if (int_sizes && int_strides &&
|
if (int_sizes && int_strides &&
|
||||||
(!storage_offset.has_value() || !storage_offset->is_symbolic()) &&
|
// NB: storage_offset guaranteed to be positive
|
||||||
|
(!storage_offset.has_value() || !storage_offset->is_heap_allocated()) &&
|
||||||
!has_symbolic_sizes_strides_) {
|
!has_symbolic_sizes_strides_) {
|
||||||
set_sizes_and_strides(*int_sizes, *int_strides);
|
set_sizes_and_strides(*int_sizes, *int_strides);
|
||||||
if (storage_offset.has_value())
|
if (storage_offset.has_value())
|
||||||
|
@ -6,18 +6,16 @@
|
|||||||
using namespace c10;
|
using namespace c10;
|
||||||
#ifndef C10_MOBILE
|
#ifndef C10_MOBILE
|
||||||
void check(int64_t value) {
|
void check(int64_t value) {
|
||||||
EXPECT_TRUE(SymInt::check_range(value));
|
|
||||||
const auto i = SymInt(value);
|
const auto i = SymInt(value);
|
||||||
EXPECT_FALSE(i.is_symbolic());
|
EXPECT_EQ(i.maybe_as_int(), c10::make_optional(value));
|
||||||
EXPECT_EQ(i.as_int_unchecked(), value);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SymIntTest, ConcreteInts) {
|
TEST(SymIntTest, ConcreteInts) {
|
||||||
check(INT64_MAX);
|
check(INT64_MAX);
|
||||||
check(0);
|
check(0);
|
||||||
check(-1);
|
check(-1);
|
||||||
// This is 2^62, which is the most negative number we can support.
|
|
||||||
check(-4611686018427387904LL);
|
check(-4611686018427387904LL);
|
||||||
|
check(INT64_MIN);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SymIntTest, CheckRange) {
|
TEST(SymIntTest, CheckRange) {
|
||||||
|
@ -50,7 +50,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
m.def("get_storage", []() { return random_tensor().storage(); });
|
m.def("get_storage", []() { return random_tensor().storage(); });
|
||||||
m.def("get_symfloat", []() { return c10::SymFloat(1.0); });
|
m.def("get_symfloat", []() { return c10::SymFloat(1.0); });
|
||||||
m.def("get_symint", []() { return c10::SymInt(1); });
|
m.def("get_symint", []() { return c10::SymInt(1); });
|
||||||
m.def("get_symint_symbolic", []() { return c10::SymInt(c10::SymInt::UNCHECKED, INT64_MIN); });
|
|
||||||
m.def("get_symintarrayref", []() { return at::SymIntArrayRef({1, 2, 3}); });
|
m.def("get_symintarrayref", []() { return at::SymIntArrayRef({1, 2, 3}); });
|
||||||
m.def("get_tensor", []() { return random_tensor(); });
|
m.def("get_tensor", []() { return random_tensor(); });
|
||||||
}
|
}
|
||||||
|
@ -195,7 +195,8 @@ class TestPybindTypeCasters(common.TestCase):
|
|||||||
assert len(union_type) == 1
|
assert len(union_type) == 1
|
||||||
union_type = union_type.pop()
|
union_type = union_type.pop()
|
||||||
self.assertIs(Union, get_origin(union_type))
|
self.assertIs(Union, get_origin(union_type))
|
||||||
expected_types = set(get_args(union_type))
|
# SymInt is inconvenient to test, so don't require it
|
||||||
|
expected_types = set(get_args(union_type)) - {torch.SymInt}
|
||||||
for func in funcs:
|
for func in funcs:
|
||||||
val = func()
|
val = func()
|
||||||
for tp in expected_types:
|
for tp in expected_types:
|
||||||
@ -219,7 +220,7 @@ class TestPybindTypeCasters(common.TestCase):
|
|||||||
cpp_extension.get_tensor,
|
cpp_extension.get_tensor,
|
||||||
]
|
]
|
||||||
union_functions = [
|
union_functions = [
|
||||||
[cpp_extension.get_symint, cpp_extension.get_symint_symbolic],
|
[cpp_extension.get_symint],
|
||||||
]
|
]
|
||||||
for func in functions:
|
for func in functions:
|
||||||
with self.subTest(msg=f"check {func.__name__}"):
|
with self.subTest(msg=f"check {func.__name__}"):
|
||||||
|
@ -308,11 +308,11 @@ GETTER_BODY_ARRAYREF_SYMINT = """\
|
|||||||
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
|
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
|
||||||
for (auto i : c10::irange(prop.size())) {
|
for (auto i : c10::irange(prop.size())) {
|
||||||
auto si = prop[i];
|
auto si = prop[i];
|
||||||
if (si.is_symbolic()) {
|
if (auto m = si.maybe_as_int()) {
|
||||||
|
PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(*m));
|
||||||
|
} else {
|
||||||
auto py_symint = py::cast(si).release().ptr();
|
auto py_symint = py::cast(si).release().ptr();
|
||||||
PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
|
PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
|
||||||
} else {
|
|
||||||
PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(si.as_int_unchecked()));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return tup;
|
return tup;
|
||||||
@ -331,7 +331,11 @@ return PyLong_FromUnsignedLong((int64_t) prop);
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
GETTER_BODY_SYMINT = """\
|
GETTER_BODY_SYMINT = """\
|
||||||
return prop.is_symbolic() ? py::cast(prop).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked());
|
if (auto m = prop.maybe_as_int()) {
|
||||||
|
return PyLong_FromUnsignedLong(*m);
|
||||||
|
} else {
|
||||||
|
return py::cast(prop).release().ptr();
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
GETTER_BODY_DOUBLE = """\
|
GETTER_BODY_DOUBLE = """\
|
||||||
|
@ -55,15 +55,7 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
|
|||||||
|
|
||||||
for (auto i : c10::irange(sym_sizes.size())) {
|
for (auto i : c10::irange(sym_sizes.size())) {
|
||||||
auto si = sym_sizes[i];
|
auto si = sym_sizes[i];
|
||||||
if (si.is_symbolic()) {
|
if (auto m = si.maybe_as_int()) {
|
||||||
TORCH_CHECK(
|
|
||||||
!torch::jit::tracer::isTracing(),
|
|
||||||
"JIT Tracing of SymInts isn't supported");
|
|
||||||
auto py_symint = py::cast(si).release().ptr();
|
|
||||||
if (!py_symint)
|
|
||||||
throw python_error();
|
|
||||||
PyTuple_SET_ITEM(ret.get(), i, py_symint);
|
|
||||||
} else {
|
|
||||||
if (torch::jit::tracer::isTracing()) {
|
if (torch::jit::tracer::isTracing()) {
|
||||||
PyObject* py_size_tensor =
|
PyObject* py_size_tensor =
|
||||||
THPVariable_Wrap(torch::jit::tracer::getSizeOf(self_, i));
|
THPVariable_Wrap(torch::jit::tracer::getSizeOf(self_, i));
|
||||||
@ -71,9 +63,16 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
|
|||||||
throw python_error();
|
throw python_error();
|
||||||
PyTuple_SET_ITEM(ret.get(), i, py_size_tensor);
|
PyTuple_SET_ITEM(ret.get(), i, py_size_tensor);
|
||||||
} else {
|
} else {
|
||||||
PyTuple_SET_ITEM(
|
PyTuple_SET_ITEM(ret.get(), i, THPUtils_packInt64(*m));
|
||||||
ret.get(), i, THPUtils_packInt64(si.as_int_unchecked()));
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(
|
||||||
|
!torch::jit::tracer::isTracing(),
|
||||||
|
"JIT Tracing of SymInts isn't supported");
|
||||||
|
auto py_symint = py::cast(si).release().ptr();
|
||||||
|
if (!py_symint)
|
||||||
|
throw python_error();
|
||||||
|
PyTuple_SET_ITEM(ret.get(), i, py_symint);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ret.release();
|
return ret.release();
|
||||||
|
@ -88,7 +88,8 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
|||||||
// no op, there is nothing to tag
|
// no op, there is nothing to tag
|
||||||
break;
|
break;
|
||||||
case c10::SymIntType::Kind:
|
case c10::SymIntType::Kind:
|
||||||
TORCH_CHECK(!w.value.toSymInt().is_symbolic());
|
// TODO: Can this really show up though? :think:
|
||||||
|
TORCH_CHECK(!w.value.toSymInt().is_heap_allocated());
|
||||||
// no op, there is nothing to tag
|
// no op, there is nothing to tag
|
||||||
break;
|
break;
|
||||||
case c10::SymFloatType::Kind:
|
case c10::SymFloatType::Kind:
|
||||||
|
@ -127,12 +127,14 @@ static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline Value GetSymIntValue(c10::SymInt a) {
|
inline Value GetSymIntValue(c10::SymInt a) {
|
||||||
return Value(
|
if (auto ma = a.maybe_as_int()) {
|
||||||
a.is_symbolic()
|
return Value(MakeScalar(*ma, at::kLong), 0);
|
||||||
? dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImpl().get())
|
} else {
|
||||||
->node_
|
return Value(
|
||||||
: MakeScalar(a.as_int_unchecked(), at::kLong),
|
dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImplUnowned())
|
||||||
0);
|
->node_,
|
||||||
|
0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: this should return Value
|
// TODO: this should return Value
|
||||||
|
@ -451,25 +451,24 @@ std::vector<Shape> compute_shape_expand(
|
|||||||
padded_self.end(), self.sizes().begin(), self.sizes().end());
|
padded_self.end(), self.sizes().begin(), self.sizes().end());
|
||||||
std::vector<int64_t> target_size(_sizes.size());
|
std::vector<int64_t> target_size(_sizes.size());
|
||||||
for (const auto idx : c10::irange(_sizes.size())) {
|
for (const auto idx : c10::irange(_sizes.size())) {
|
||||||
if (_sizes[idx].is_symbolic()) {
|
if (auto ma = _sizes[idx].maybe_as_int()) {
|
||||||
c10::SymNode symbolicIntNode = _sizes[idx].toSymNodeImpl();
|
target_size[idx] = *ma;
|
||||||
auto* lazySymNode =
|
if (*ma == -1) {
|
||||||
dynamic_cast<torch::lazy::SymNodeImpl*>(symbolicIntNode.get());
|
// -1 can't be specified for non-existing dimensions
|
||||||
|
TORCH_CHECK(idx >= num_new_dimensions);
|
||||||
|
target_size[idx] = padded_self[idx];
|
||||||
|
} else {
|
||||||
|
target_size[idx] = *ma;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto* lazySymNode = dynamic_cast<torch::lazy::SymNodeImpl*>(
|
||||||
|
_sizes[idx].toSymNodeImplUnowned());
|
||||||
TORCH_INTERNAL_ASSERT(lazySymNode);
|
TORCH_INTERNAL_ASSERT(lazySymNode);
|
||||||
auto size_node = lazySymNode->node_;
|
auto size_node = lazySymNode->node_;
|
||||||
auto static_value =
|
auto static_value =
|
||||||
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node)
|
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node)
|
||||||
->getStaticValue();
|
->getStaticValue();
|
||||||
target_size[idx] = static_value;
|
target_size[idx] = static_value;
|
||||||
} else {
|
|
||||||
target_size[idx] = _sizes[idx].as_int_unchecked();
|
|
||||||
if (_sizes[idx].as_int_unchecked() == -1) {
|
|
||||||
// -1 can't be specified for non-existing dimensions
|
|
||||||
TORCH_CHECK(idx >= num_new_dimensions);
|
|
||||||
target_size[idx] = padded_self[idx];
|
|
||||||
} else {
|
|
||||||
target_size[idx] = _sizes[idx].as_int_unchecked();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return {Shape(self.scalar_type(), target_size)};
|
return {Shape(self.scalar_type(), target_size)};
|
||||||
|
@ -30,22 +30,22 @@ py::handle type_caster<c10::SymInt>::cast(
|
|||||||
c10::SymInt si,
|
c10::SymInt si,
|
||||||
return_value_policy /* policy */,
|
return_value_policy /* policy */,
|
||||||
handle /* parent */) {
|
handle /* parent */) {
|
||||||
if (si.is_symbolic()) {
|
if (auto m = si.maybe_as_int()) {
|
||||||
auto* py_node =
|
return py::cast(*m).release();
|
||||||
dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
|
} else {
|
||||||
|
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
|
||||||
|
si.toSymNodeImplUnowned());
|
||||||
if (py_node) {
|
if (py_node) {
|
||||||
// Return the Python directly (unwrap)
|
// Return the Python directly (unwrap)
|
||||||
return torch::get_symint_class()(py_node->getPyObj()).release();
|
return torch::get_symint_class()(py_node->getPyObj()).release();
|
||||||
} else {
|
} else {
|
||||||
// Wrap the C++ into Python
|
// Wrap the C++ into Python
|
||||||
auto inner = py::cast(si.toSymNodeImpl());
|
auto inner = py::cast(si.toSymNode());
|
||||||
if (!inner) {
|
if (!inner) {
|
||||||
throw python_error();
|
throw python_error();
|
||||||
}
|
}
|
||||||
return torch::get_symint_class()(inner).release();
|
return torch::get_symint_class()(inner).release();
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
return py::cast(si.as_int_unchecked()).release();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -494,12 +494,12 @@ inline std::vector<int64_t> PythonArgs::intlist(int i) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline PyObject* toPyObject(c10::SymInt symint) {
|
inline PyObject* toPyObject(c10::SymInt symint) {
|
||||||
if (symint.is_symbolic()) {
|
if (auto m = symint.maybe_as_int()) {
|
||||||
|
return THPUtils_packInt64(*m);
|
||||||
|
} else {
|
||||||
auto r = py::cast(symint).release().ptr();
|
auto r = py::cast(symint).release().ptr();
|
||||||
TORCH_INTERNAL_ASSERT(r);
|
TORCH_INTERNAL_ASSERT(r);
|
||||||
return r;
|
return r;
|
||||||
} else {
|
|
||||||
return THPUtils_packInt64(symint.as_int_unchecked());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user