mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add SingletonSymIntNode (#107089)
Adds `SingletonSymNodeImpl` (alternatively, `SkolemSymNodeImpl`). This is a int-like object that only allows the`eq` operation; any other operation produces an error. The main complexity is that we require operations that dispatch to SymNode must take and return SymNodes, but when performing operations involving `SingletonSymNodeImpl`, operations involving SymNode can return non-SymNode bools. For more discussion see [here](https://docs.google.com/document/d/18iqMdnHlUnvoTz4BveBbyWFi_tCRmFoqMFdBHKmCm_k/edit) - Introduce `ConstantSymNodeImpl` a generalization of `LargeNegativeIntSymNodeImpl` and replace usage of `LargeNegativeIntSymNodeImpl` in SymInt. - Also use ConstantSymNodeImpl to enable SymBool to store its data on a SymNode. Remove the assumption that if SymBool holds a non-null SymNode, it must be symbolic. Pull Request resolved: https://github.com/pytorch/pytorch/pull/107089 Approved by: https://github.com/ezyang ghstack dependencies: #107839
This commit is contained in:
committed by
PyTorch MergeBot
parent
a41d15e458
commit
d7130e9704
@ -620,12 +620,12 @@ public:
|
||||
c10::SymFloat toSymFloat() const&;
|
||||
|
||||
IValue(c10::SymBool i) {
|
||||
if (i.is_symbolic()) {
|
||||
if (auto mi = i.maybe_as_bool()) {
|
||||
tag = Tag::Bool;
|
||||
payload.u.as_int = *mi;
|
||||
} else {
|
||||
tag = Tag::SymBool;
|
||||
payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
|
||||
} else {
|
||||
tag = Tag::Bool;
|
||||
payload.u.as_bool = i.as_bool_unchecked();
|
||||
}
|
||||
}
|
||||
|
||||
|
23
c10/core/ConstantSymNodeImpl.cpp
Normal file
23
c10/core/ConstantSymNodeImpl.cpp
Normal file
@ -0,0 +1,23 @@
|
||||
#include <c10/core/ConstantSymNodeImpl.h>
|
||||
|
||||
namespace c10 {
|
||||
// Temporary hack to avoid having to implement multiple dispatch for now
|
||||
// Currently even if we have this method, we still raise an error when we get
|
||||
// to SingletonSymNode::eq since comparing with non-singleton is disallowed.
|
||||
// However, we may change that behavior in the future.
|
||||
template <typename T>
|
||||
c10::SymNode ConstantSymNodeImpl<T>::eq(const c10::SymNode& other) {
|
||||
TORCH_INTERNAL_ASSERT(other->singleton_int().has_value());
|
||||
c10::raw::intrusive_ptr::incref(this);
|
||||
return other->eq(c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim(this));
|
||||
}
|
||||
template <typename T>
|
||||
c10::SymNode ConstantSymNodeImpl<T>::ne(const c10::SymNode& other) {
|
||||
TORCH_INTERNAL_ASSERT(other->singleton_int().has_value());
|
||||
c10::raw::intrusive_ptr::incref(this);
|
||||
return other->ne(c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim(this));
|
||||
}
|
||||
|
||||
template class ConstantSymNodeImpl<bool>;
|
||||
template class ConstantSymNodeImpl<int64_t>;
|
||||
} // namespace c10
|
79
c10/core/ConstantSymNodeImpl.h
Normal file
79
c10/core/ConstantSymNodeImpl.h
Normal file
@ -0,0 +1,79 @@
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <c10/util/variant.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// Unlike other SymNodeImpl, this cannot be "dispatched" conventionally,
|
||||
// as it typically needs to defer to another SymNodeImpl
|
||||
//
|
||||
// Can either represent a bool, int (don't support float yet) this is useful
|
||||
// for representing otherwise unrepresentable large negative integer constant.
|
||||
template <typename T>
|
||||
class C10_API ConstantSymNodeImpl : public SymNodeImpl {
|
||||
static_assert(
|
||||
std::is_same<T, int64_t>::value || std::is_same<T, bool>::value,
|
||||
"ConstantSymNodeImpl can only accept int64_t or bool types");
|
||||
|
||||
public:
|
||||
ConstantSymNodeImpl(T val) : value_(val) {}
|
||||
|
||||
bool is_int() override {
|
||||
return std::is_same<T, int64_t>::value;
|
||||
}
|
||||
bool is_bool() override {
|
||||
return std::is_same<T, bool>::value;
|
||||
}
|
||||
bool is_float() override {
|
||||
return false;
|
||||
}
|
||||
int64_t guard_int(const char* file, int64_t line) override {
|
||||
TORCH_CHECK(is_int(), "not an int");
|
||||
return int_();
|
||||
}
|
||||
bool guard_bool(const char* file, int64_t line) override {
|
||||
TORCH_CHECK(is_bool(), "not a bool");
|
||||
return bool_();
|
||||
}
|
||||
double guard_float(const char* file, int64_t line) override {
|
||||
TORCH_CHECK(false, "not a float");
|
||||
}
|
||||
int64_t int_() override {
|
||||
TORCH_CHECK(is_int(), "not an int");
|
||||
return c10::get<int64_t>(value_);
|
||||
}
|
||||
bool bool_() override {
|
||||
TORCH_CHECK(is_bool(), "not a bool");
|
||||
return c10::get<bool>(value_);
|
||||
}
|
||||
bool has_hint() override {
|
||||
return true;
|
||||
}
|
||||
c10::SymNode eq(const c10::SymNode& other) override;
|
||||
c10::SymNode ne(const c10::SymNode& other) override;
|
||||
std::string str() override {
|
||||
if (is_int()) {
|
||||
return std::to_string(c10::get<int64_t>(value_));
|
||||
} else {
|
||||
return c10::get<bool>(value_) ? "true" : "false";
|
||||
}
|
||||
}
|
||||
c10::optional<int64_t> constant_int() override {
|
||||
if (is_int()) {
|
||||
return c10::get<int64_t>(value_);
|
||||
} else {
|
||||
return c10::nullopt;
|
||||
}
|
||||
}
|
||||
c10::optional<bool> constant_bool() override {
|
||||
if (is_bool()) {
|
||||
return c10::get<bool>(value_);
|
||||
} else {
|
||||
return c10::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
c10::variant<int64_t, bool> value_;
|
||||
};
|
||||
|
||||
} // namespace c10
|
@ -1,50 +0,0 @@
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// Represents an otherwise unrepresentable large negative integer constant.
|
||||
// Unlike other SymNodeImpl, this cannot be "dispatched" conventionally,
|
||||
// as it typically needs to defer to another SymNodeImpl
|
||||
class C10_API LargeNegativeIntSymNodeImpl : public SymNodeImpl {
|
||||
public:
|
||||
LargeNegativeIntSymNodeImpl(int64_t val) : val_(val) {}
|
||||
|
||||
bool is_int() override {
|
||||
return true;
|
||||
};
|
||||
bool is_bool() override {
|
||||
return false;
|
||||
};
|
||||
bool is_float() override {
|
||||
return false;
|
||||
};
|
||||
int64_t guard_int(const char* file, int64_t line) override {
|
||||
return val_;
|
||||
};
|
||||
bool guard_bool(const char* file, int64_t line) override {
|
||||
TORCH_CHECK(false, "not a bool");
|
||||
};
|
||||
double guard_float(const char* file, int64_t line) override {
|
||||
TORCH_CHECK(false, "not a float");
|
||||
};
|
||||
int64_t int_() override {
|
||||
return true;
|
||||
};
|
||||
bool bool_() override {
|
||||
return false;
|
||||
};
|
||||
bool has_hint() override {
|
||||
return true;
|
||||
};
|
||||
std::string str() override {
|
||||
return std::to_string(val_);
|
||||
};
|
||||
int64_t large_negative_int() override {
|
||||
return val_;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t val_;
|
||||
};
|
||||
|
||||
} // namespace c10
|
@ -300,12 +300,12 @@ class C10_API Scalar {
|
||||
}
|
||||
|
||||
Scalar(c10::SymBool sb) {
|
||||
if (sb.is_symbolic()) {
|
||||
if (auto m = sb.maybe_as_bool()) {
|
||||
tag = Tag::HAS_b;
|
||||
v.i = *m;
|
||||
} else {
|
||||
tag = Tag::HAS_sb;
|
||||
v.p = std::move(sb).release();
|
||||
} else {
|
||||
tag = Tag::HAS_b;
|
||||
v.d = sb.as_bool_unchecked();
|
||||
}
|
||||
}
|
||||
|
||||
|
3
c10/core/SingletonSymNodeImpl.cpp
Normal file
3
c10/core/SingletonSymNodeImpl.cpp
Normal file
@ -0,0 +1,3 @@
|
||||
#include <c10/core/SingletonSymNodeImpl.h>
|
||||
|
||||
namespace c10 {} // namespace c10
|
123
c10/core/SingletonSymNodeImpl.h
Normal file
123
c10/core/SingletonSymNodeImpl.h
Normal file
@ -0,0 +1,123 @@
|
||||
#include <c10/core/ConstantSymNodeImpl.h>
|
||||
#include <c10/core/SymBool.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// An int-like object that only defines the equality operator.
|
||||
class C10_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
public:
|
||||
// CAUTION: you should probably not be constructing these directly; please
|
||||
// the higher-level API in python instead (TODO: actually introduce that).
|
||||
explicit SingletonSymNodeImpl(int64_t val) : val_(val) {}
|
||||
|
||||
bool bool_() override {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_int() override {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_float() override {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_bool() override {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool has_hint() override {
|
||||
return true;
|
||||
}
|
||||
|
||||
c10::SymNode wrap_int(int64_t num) override {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(num));
|
||||
};
|
||||
|
||||
int64_t guard_int(const char* file, int64_t line) override {
|
||||
// TODO: when is this used?
|
||||
TORCH_CHECK(false);
|
||||
}
|
||||
|
||||
double guard_float(const char* file, int64_t line) override {
|
||||
TORCH_CHECK(false, "not a float");
|
||||
}
|
||||
|
||||
bool guard_bool(const char* file, int64_t line) override {
|
||||
TORCH_CHECK(false, "not a bool");
|
||||
}
|
||||
|
||||
int64_t int_() override {
|
||||
// TODO: when is this used?
|
||||
TORCH_CHECK(false);
|
||||
}
|
||||
|
||||
std::string str() override {
|
||||
return "j" + std::to_string(val_);
|
||||
}
|
||||
|
||||
c10::SymNode eq(const c10::SymNode& other) override {
|
||||
c10::optional<int64_t> c = other->singleton_int();
|
||||
TORCH_CHECK(
|
||||
c,
|
||||
"SingletonSymNode can only be compared with SingletonSymNode, but got ",
|
||||
other->str());
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(val_ == *c));
|
||||
}
|
||||
|
||||
c10::SymNode ne(const c10::SymNode& other) override {
|
||||
c10::optional<int64_t> c = other->singleton_int();
|
||||
TORCH_CHECK(
|
||||
c,
|
||||
"SingletonSymNode can only be compared with SingletonSymNode, but got ",
|
||||
other->str());
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(val_ != *c));
|
||||
}
|
||||
|
||||
c10::optional<int64_t> singleton_int() override {
|
||||
return val_;
|
||||
}
|
||||
|
||||
#define DEFINE_BINARY_NOT_SUPPORTED(name) \
|
||||
c10::SymNode name(const c10::SymNode& other) override { \
|
||||
TORCH_CHECK(false, #name " not supported by SingletonSymNode"); \
|
||||
}
|
||||
|
||||
DEFINE_BINARY_NOT_SUPPORTED(add)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(sub)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(mul)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(truediv)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(pow)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(floordiv)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(mod)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(gt)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(lt)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(ge)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(sym_min)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(sym_max)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(sym_and)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(sym_or)
|
||||
|
||||
#undef DEFINE_BINARY_NOT_SUPPORTED
|
||||
|
||||
#define DEFINE_NOT_SUPPORTED(name) \
|
||||
c10::SymNode name() override { \
|
||||
TORCH_CHECK(false, #name " is not supported by SingletonSymNode"); \
|
||||
}
|
||||
|
||||
DEFINE_NOT_SUPPORTED(sym_not)
|
||||
DEFINE_NOT_SUPPORTED(ceil)
|
||||
DEFINE_NOT_SUPPORTED(floor)
|
||||
DEFINE_NOT_SUPPORTED(neg)
|
||||
DEFINE_NOT_SUPPORTED(clone)
|
||||
DEFINE_NOT_SUPPORTED(sym_float)
|
||||
|
||||
#undef DEFINE_NOT_SUPPORTED
|
||||
|
||||
private:
|
||||
int64_t val_;
|
||||
};
|
||||
|
||||
} // namespace c10
|
@ -1,3 +1,4 @@
|
||||
#include <c10/core/ConstantSymNodeImpl.h>
|
||||
#include <c10/core/SymBool.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <array>
|
||||
@ -6,87 +7,76 @@
|
||||
namespace c10 {
|
||||
|
||||
SymNode SymBool::toSymNodeImpl() const {
|
||||
TORCH_CHECK(is_symbolic());
|
||||
TORCH_CHECK(is_heap_allocated());
|
||||
return SymNode::reclaim_copy(toSymNodeImplUnowned());
|
||||
}
|
||||
|
||||
SymNode SymBool::wrap_node(const SymNode& base) const {
|
||||
if (is_symbolic()) {
|
||||
return toSymNodeImpl();
|
||||
if (auto ma = maybe_as_bool()) {
|
||||
return base->wrap_bool(*ma);
|
||||
} else {
|
||||
return base->wrap_bool(as_bool_unchecked());
|
||||
return toSymNodeImpl();
|
||||
}
|
||||
}
|
||||
|
||||
static std::array<SymNode, 2> normalize_symbools(
|
||||
const SymBool& a_,
|
||||
const SymBool& b_) {
|
||||
SymNode a, b;
|
||||
if (a_.is_symbolic())
|
||||
a = a_.toSymNodeImpl();
|
||||
if (b_.is_symbolic())
|
||||
b = b_.toSymNodeImpl();
|
||||
#define DEFINE_BINARY(API, OP, METHOD, RET) \
|
||||
RET SymBool::API(const SymBool& sci) const { \
|
||||
if (auto ma = maybe_as_bool()) { \
|
||||
if (auto mb = sci.maybe_as_bool()) { \
|
||||
return RET(OP(*ma, *mb)); \
|
||||
} else { \
|
||||
auto b = sci.toSymNodeImpl(); \
|
||||
return RET(b->wrap_bool(*ma)->METHOD(b)); \
|
||||
} \
|
||||
} else { \
|
||||
if (auto mb = sci.maybe_as_bool()) { \
|
||||
auto a = toSymNodeImplUnowned(); \
|
||||
return RET(a->METHOD(a->wrap_bool(*mb))); \
|
||||
} else { \
|
||||
return RET(toSymNodeImplUnowned()->METHOD(sci.toSymNodeImpl())); \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
SymNodeImpl* common = a ? a.get() : b.get();
|
||||
if (!a) {
|
||||
a = common->wrap_bool(a_.as_bool_unchecked());
|
||||
}
|
||||
if (!b) {
|
||||
b = common->wrap_bool(b_.as_bool_unchecked());
|
||||
}
|
||||
return {std::move(a), std::move(b)};
|
||||
}
|
||||
|
||||
SymBool SymBool::sym_and(const SymBool& sci) const {
|
||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
||||
return SymBool(data_ && sci.data_);
|
||||
}
|
||||
auto res = normalize_symbools(*this, sci);
|
||||
return SymBool(res[0]->sym_and(res[1]));
|
||||
}
|
||||
|
||||
SymBool SymBool::sym_or(const SymBool& sci) const {
|
||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
||||
return SymBool(data_ || sci.data_);
|
||||
}
|
||||
auto res = normalize_symbools(*this, sci);
|
||||
return SymBool(res[0]->sym_or(res[1]));
|
||||
}
|
||||
// clang-format off
|
||||
DEFINE_BINARY(sym_and, std::logical_and<>(), sym_and, SymBool)
|
||||
DEFINE_BINARY(sym_or, std::logical_or<>(), sym_or, SymBool)
|
||||
// clang-format on
|
||||
|
||||
SymBool SymBool::sym_not() const {
|
||||
if (!is_symbolic()) {
|
||||
return SymBool(!data_);
|
||||
if (auto ma = maybe_as_bool()) {
|
||||
return SymBool(!*ma);
|
||||
}
|
||||
return SymBool(toSymNodeImpl()->sym_not());
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const SymBool& s) {
|
||||
if (s.is_symbolic()) {
|
||||
os << s.toSymNodeImpl()->str();
|
||||
if (auto ma = s.maybe_as_bool()) {
|
||||
os << *ma;
|
||||
} else {
|
||||
os << s.as_bool_unchecked();
|
||||
os << s.toSymNodeImpl()->str();
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
bool SymBool::guard_bool(const char* file, int64_t line) const {
|
||||
if (!is_symbolic()) {
|
||||
return data_;
|
||||
if (auto ma = maybe_as_bool()) {
|
||||
return *ma;
|
||||
}
|
||||
SymNode a = toSymNodeImpl();
|
||||
return a->guard_bool(file, line);
|
||||
}
|
||||
|
||||
bool SymBool::expect_true(const char* file, int64_t line) const {
|
||||
if (!is_symbolic()) {
|
||||
return data_;
|
||||
if (auto ma = maybe_as_bool()) {
|
||||
return *ma;
|
||||
}
|
||||
SymNode a = toSymNodeImpl();
|
||||
return a->expect_true(file, line);
|
||||
}
|
||||
|
||||
bool SymBool::has_hint() const {
|
||||
if (!is_symbolic()) {
|
||||
if (auto ma = maybe_as_bool()) {
|
||||
return true;
|
||||
}
|
||||
return toSymNodeImpl()->has_hint();
|
||||
|
@ -23,15 +23,16 @@ class C10_API SymBool {
|
||||
return std::move(ptr_).release();
|
||||
}
|
||||
|
||||
// Only valid if is_symbolic()
|
||||
// Only valid if is_heap_allocated()
|
||||
SymNode toSymNodeImpl() const;
|
||||
|
||||
// Guaranteed to return a SymNode, wrapping using base if necessary
|
||||
SymNode wrap_node(const SymNode& base) const;
|
||||
|
||||
bool expect_bool() const {
|
||||
TORCH_CHECK(!is_symbolic());
|
||||
return data_;
|
||||
c10::optional<bool> c = maybe_as_bool();
|
||||
TORCH_CHECK(c.has_value());
|
||||
return *c;
|
||||
}
|
||||
|
||||
SymBool sym_and(const SymBool&) const;
|
||||
@ -56,14 +57,21 @@ class C10_API SymBool {
|
||||
|
||||
bool has_hint() const;
|
||||
|
||||
C10_ALWAYS_INLINE bool is_symbolic() const {
|
||||
return ptr_;
|
||||
}
|
||||
|
||||
bool as_bool_unchecked() const {
|
||||
return data_;
|
||||
}
|
||||
|
||||
c10::optional<bool> maybe_as_bool() const {
|
||||
if (!is_heap_allocated()) {
|
||||
return c10::make_optional(data_);
|
||||
}
|
||||
return toSymNodeImplUnowned()->constant_bool();
|
||||
}
|
||||
|
||||
bool is_heap_allocated() const {
|
||||
return ptr_;
|
||||
}
|
||||
|
||||
private:
|
||||
// TODO: optimize to union
|
||||
bool data_;
|
||||
|
@ -1,4 +1,4 @@
|
||||
#include <c10/core/LargeNegativeIntSymNodeImpl.h>
|
||||
#include <c10/core/ConstantSymNodeImpl.h>
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
@ -13,7 +13,7 @@ namespace c10 {
|
||||
// Postcondition: invariants on SymInt are fixed
|
||||
void SymInt::promote_to_negative() {
|
||||
auto s =
|
||||
SymInt(SymNode(c10::make_intrusive<LargeNegativeIntSymNodeImpl>(data_)));
|
||||
SymInt(SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(data_)));
|
||||
// Similar to move operator=, but do NOT release data_
|
||||
data_ = s.data_;
|
||||
s.data_ = 0;
|
||||
|
@ -145,10 +145,10 @@ class C10_API SymInt {
|
||||
// number can be used to diagnose overspecialization.
|
||||
int64_t guard_int(const char* file, int64_t line) const;
|
||||
|
||||
// Distinguish actual symbolic values from large negative integers.
|
||||
// Distinguish actual symbolic values from constants stored on the heap
|
||||
bool is_symbolic() const {
|
||||
return is_heap_allocated() &&
|
||||
toSymNodeImplUnowned()->large_negative_int() == 0;
|
||||
!toSymNodeImplUnowned()->constant_int().has_value();
|
||||
}
|
||||
|
||||
// N.B. It's important to keep this definition in the header
|
||||
@ -215,15 +215,10 @@ class C10_API SymInt {
|
||||
return c10::make_optional(data_);
|
||||
}
|
||||
auto* node = toSymNodeImplUnowned();
|
||||
int64_t c = node->large_negative_int();
|
||||
if (c != 0) {
|
||||
return c10::make_optional(c);
|
||||
if (auto c = node->constant_int()) {
|
||||
return c;
|
||||
}
|
||||
c10::optional<int64_t> d = node->maybe_as_int();
|
||||
if (d.has_value()) {
|
||||
return d;
|
||||
}
|
||||
return c10::nullopt;
|
||||
return node->maybe_as_int();
|
||||
}
|
||||
|
||||
// Return whether the integer is directly coercible to a SymInt
|
||||
|
@ -169,8 +169,14 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
||||
virtual std::string str() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual int64_t large_negative_int() {
|
||||
return 0; // not a large negative int!
|
||||
virtual c10::optional<int64_t> singleton_int() {
|
||||
return c10::nullopt;
|
||||
}
|
||||
virtual c10::optional<int64_t> constant_int() {
|
||||
return c10::nullopt;
|
||||
}
|
||||
virtual c10::optional<bool> constant_bool() {
|
||||
return c10::nullopt;
|
||||
}
|
||||
virtual c10::optional<int64_t> maybe_as_int() {
|
||||
return c10::nullopt;
|
||||
@ -178,7 +184,7 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
||||
std::ostream& operator<<(std::ostream& os) {
|
||||
os << str();
|
||||
return os;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace c10
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/core/SingletonSymNodeImpl.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
|
||||
@ -21,4 +22,32 @@ TEST(SymIntTest, ConcreteInts) {
|
||||
TEST(SymIntTest, CheckRange) {
|
||||
EXPECT_FALSE(SymInt::check_range(INT64_MIN));
|
||||
}
|
||||
|
||||
TEST(SymIntTest, SingletonSymNode) {
|
||||
auto a = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1)));
|
||||
auto b = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1)));
|
||||
auto c = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2)));
|
||||
auto d = c10::SymInt(3);
|
||||
|
||||
ASSERT_TRUE(a == a);
|
||||
ASSERT_TRUE(a == b);
|
||||
ASSERT_FALSE(a != a);
|
||||
ASSERT_FALSE(a != b);
|
||||
ASSERT_FALSE(a == c);
|
||||
ASSERT_TRUE(a != c);
|
||||
|
||||
// Tentaively throw an error when comparing with a non-singleton, this is not
|
||||
// necessarily the right behavior.
|
||||
ASSERT_THROW((void)(a == d), c10::Error);
|
||||
ASSERT_THROW((void)(a != d), c10::Error);
|
||||
ASSERT_THROW((void)(d == a), c10::Error);
|
||||
ASSERT_THROW((void)(d != a), c10::Error);
|
||||
|
||||
ASSERT_THROW((void)(a >= b), c10::Error); // "not supported by..."
|
||||
ASSERT_THROW((void)(a >= d), c10::Error); // "not supported by..."
|
||||
ASSERT_THROW((void)(d >= a), c10::Error); // "NYI"
|
||||
}
|
||||
#endif
|
||||
|
@ -97,7 +97,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
||||
// no op, there is nothing to tag
|
||||
break;
|
||||
case c10::SymBoolType::Kind:
|
||||
TORCH_CHECK(!w.value.toSymBool().is_symbolic());
|
||||
TORCH_CHECK(!w.value.toSymBool().is_heap_allocated());
|
||||
// no op, there is nothing to tag
|
||||
break;
|
||||
case DynamicType::Kind:
|
||||
|
@ -99,14 +99,14 @@ py::handle type_caster<c10::SymBool>::cast(
|
||||
const c10::SymBool& si,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
if (si.is_symbolic()) {
|
||||
if (auto m = si.maybe_as_bool()) {
|
||||
return py::cast(*m).release();
|
||||
} else {
|
||||
// TODO: generalize this to work with C++ backed class
|
||||
auto* py_node =
|
||||
dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
|
||||
TORCH_INTERNAL_ASSERT(py_node);
|
||||
return torch::get_symbool_class()(py_node->getPyObj()).release();
|
||||
} else {
|
||||
return py::cast(si.as_bool_unchecked()).release();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <torch/csrc/autograd/python_variable.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
|
||||
#include <c10/core/SingletonSymNodeImpl.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/stl.h>
|
||||
@ -738,6 +739,11 @@ void initDispatchBindings(PyObject* module) {
|
||||
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
|
||||
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
|
||||
});
|
||||
|
||||
m.def("_get_singleton_int", [](int64_t data) {
|
||||
return c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(data)));
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: dedupe with the kernel
|
||||
|
Reference in New Issue
Block a user