Add SingletonSymIntNode (#107089)

Adds `SingletonSymNodeImpl` (alternatively, `SkolemSymNodeImpl`). This is a int-like object that only allows  the`eq` operation; any other operation produces an error.

The main complexity is that we require operations that dispatch to SymNode must take and return SymNodes, but when performing operations involving `SingletonSymNodeImpl`, operations involving SymNode can return non-SymNode bools.  For more discussion see [here](https://docs.google.com/document/d/18iqMdnHlUnvoTz4BveBbyWFi_tCRmFoqMFdBHKmCm_k/edit)
- Introduce `ConstantSymNodeImpl` a generalization of `LargeNegativeIntSymNodeImpl` and replace usage of `LargeNegativeIntSymNodeImpl`  in SymInt.
- Also use ConstantSymNodeImpl to enable SymBool to store its data on a SymNode. Remove the  assumption that if SymBool holds a non-null SymNode, it must be symbolic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107089
Approved by: https://github.com/ezyang
ghstack dependencies: #107839
This commit is contained in:
soulitzer
2023-08-24 12:34:18 -04:00
committed by PyTorch MergeBot
parent a41d15e458
commit d7130e9704
16 changed files with 343 additions and 131 deletions

View File

@ -620,12 +620,12 @@ public:
c10::SymFloat toSymFloat() const&;
IValue(c10::SymBool i) {
if (i.is_symbolic()) {
if (auto mi = i.maybe_as_bool()) {
tag = Tag::Bool;
payload.u.as_int = *mi;
} else {
tag = Tag::SymBool;
payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
} else {
tag = Tag::Bool;
payload.u.as_bool = i.as_bool_unchecked();
}
}

View File

@ -0,0 +1,23 @@
#include <c10/core/ConstantSymNodeImpl.h>
namespace c10 {
// Temporary hack to avoid having to implement multiple dispatch for now
// Currently even if we have this method, we still raise an error when we get
// to SingletonSymNode::eq since comparing with non-singleton is disallowed.
// However, we may change that behavior in the future.
template <typename T>
c10::SymNode ConstantSymNodeImpl<T>::eq(const c10::SymNode& other) {
TORCH_INTERNAL_ASSERT(other->singleton_int().has_value());
c10::raw::intrusive_ptr::incref(this);
return other->eq(c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim(this));
}
template <typename T>
c10::SymNode ConstantSymNodeImpl<T>::ne(const c10::SymNode& other) {
TORCH_INTERNAL_ASSERT(other->singleton_int().has_value());
c10::raw::intrusive_ptr::incref(this);
return other->ne(c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim(this));
}
template class ConstantSymNodeImpl<bool>;
template class ConstantSymNodeImpl<int64_t>;
} // namespace c10

View File

@ -0,0 +1,79 @@
#include <c10/core/SymNodeImpl.h>
#include <c10/util/variant.h>
namespace c10 {
// Unlike other SymNodeImpl, this cannot be "dispatched" conventionally,
// as it typically needs to defer to another SymNodeImpl
//
// Can either represent a bool, int (don't support float yet) this is useful
// for representing otherwise unrepresentable large negative integer constant.
template <typename T>
class C10_API ConstantSymNodeImpl : public SymNodeImpl {
static_assert(
std::is_same<T, int64_t>::value || std::is_same<T, bool>::value,
"ConstantSymNodeImpl can only accept int64_t or bool types");
public:
ConstantSymNodeImpl(T val) : value_(val) {}
bool is_int() override {
return std::is_same<T, int64_t>::value;
}
bool is_bool() override {
return std::is_same<T, bool>::value;
}
bool is_float() override {
return false;
}
int64_t guard_int(const char* file, int64_t line) override {
TORCH_CHECK(is_int(), "not an int");
return int_();
}
bool guard_bool(const char* file, int64_t line) override {
TORCH_CHECK(is_bool(), "not a bool");
return bool_();
}
double guard_float(const char* file, int64_t line) override {
TORCH_CHECK(false, "not a float");
}
int64_t int_() override {
TORCH_CHECK(is_int(), "not an int");
return c10::get<int64_t>(value_);
}
bool bool_() override {
TORCH_CHECK(is_bool(), "not a bool");
return c10::get<bool>(value_);
}
bool has_hint() override {
return true;
}
c10::SymNode eq(const c10::SymNode& other) override;
c10::SymNode ne(const c10::SymNode& other) override;
std::string str() override {
if (is_int()) {
return std::to_string(c10::get<int64_t>(value_));
} else {
return c10::get<bool>(value_) ? "true" : "false";
}
}
c10::optional<int64_t> constant_int() override {
if (is_int()) {
return c10::get<int64_t>(value_);
} else {
return c10::nullopt;
}
}
c10::optional<bool> constant_bool() override {
if (is_bool()) {
return c10::get<bool>(value_);
} else {
return c10::nullopt;
}
}
private:
c10::variant<int64_t, bool> value_;
};
} // namespace c10

View File

@ -1,50 +0,0 @@
#include <c10/core/SymNodeImpl.h>
namespace c10 {
// Represents an otherwise unrepresentable large negative integer constant.
// Unlike other SymNodeImpl, this cannot be "dispatched" conventionally,
// as it typically needs to defer to another SymNodeImpl
class C10_API LargeNegativeIntSymNodeImpl : public SymNodeImpl {
public:
LargeNegativeIntSymNodeImpl(int64_t val) : val_(val) {}
bool is_int() override {
return true;
};
bool is_bool() override {
return false;
};
bool is_float() override {
return false;
};
int64_t guard_int(const char* file, int64_t line) override {
return val_;
};
bool guard_bool(const char* file, int64_t line) override {
TORCH_CHECK(false, "not a bool");
};
double guard_float(const char* file, int64_t line) override {
TORCH_CHECK(false, "not a float");
};
int64_t int_() override {
return true;
};
bool bool_() override {
return false;
};
bool has_hint() override {
return true;
};
std::string str() override {
return std::to_string(val_);
};
int64_t large_negative_int() override {
return val_;
}
private:
int64_t val_;
};
} // namespace c10

View File

@ -300,12 +300,12 @@ class C10_API Scalar {
}
Scalar(c10::SymBool sb) {
if (sb.is_symbolic()) {
if (auto m = sb.maybe_as_bool()) {
tag = Tag::HAS_b;
v.i = *m;
} else {
tag = Tag::HAS_sb;
v.p = std::move(sb).release();
} else {
tag = Tag::HAS_b;
v.d = sb.as_bool_unchecked();
}
}

View File

@ -0,0 +1,3 @@
#include <c10/core/SingletonSymNodeImpl.h>
namespace c10 {} // namespace c10

View File

@ -0,0 +1,123 @@
#include <c10/core/ConstantSymNodeImpl.h>
#include <c10/core/SymBool.h>
#include <c10/core/SymNodeImpl.h>
#include <iostream>
namespace c10 {
// An int-like object that only defines the equality operator.
class C10_API SingletonSymNodeImpl : public SymNodeImpl {
public:
// CAUTION: you should probably not be constructing these directly; please
// the higher-level API in python instead (TODO: actually introduce that).
explicit SingletonSymNodeImpl(int64_t val) : val_(val) {}
bool bool_() override {
return false;
}
bool is_int() override {
return true;
}
bool is_float() override {
return false;
}
bool is_bool() override {
return false;
}
bool has_hint() override {
return true;
}
c10::SymNode wrap_int(int64_t num) override {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(num));
};
int64_t guard_int(const char* file, int64_t line) override {
// TODO: when is this used?
TORCH_CHECK(false);
}
double guard_float(const char* file, int64_t line) override {
TORCH_CHECK(false, "not a float");
}
bool guard_bool(const char* file, int64_t line) override {
TORCH_CHECK(false, "not a bool");
}
int64_t int_() override {
// TODO: when is this used?
TORCH_CHECK(false);
}
std::string str() override {
return "j" + std::to_string(val_);
}
c10::SymNode eq(const c10::SymNode& other) override {
c10::optional<int64_t> c = other->singleton_int();
TORCH_CHECK(
c,
"SingletonSymNode can only be compared with SingletonSymNode, but got ",
other->str());
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(val_ == *c));
}
c10::SymNode ne(const c10::SymNode& other) override {
c10::optional<int64_t> c = other->singleton_int();
TORCH_CHECK(
c,
"SingletonSymNode can only be compared with SingletonSymNode, but got ",
other->str());
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(val_ != *c));
}
c10::optional<int64_t> singleton_int() override {
return val_;
}
#define DEFINE_BINARY_NOT_SUPPORTED(name) \
c10::SymNode name(const c10::SymNode& other) override { \
TORCH_CHECK(false, #name " not supported by SingletonSymNode"); \
}
DEFINE_BINARY_NOT_SUPPORTED(add)
DEFINE_BINARY_NOT_SUPPORTED(sub)
DEFINE_BINARY_NOT_SUPPORTED(mul)
DEFINE_BINARY_NOT_SUPPORTED(truediv)
DEFINE_BINARY_NOT_SUPPORTED(pow)
DEFINE_BINARY_NOT_SUPPORTED(floordiv)
DEFINE_BINARY_NOT_SUPPORTED(mod)
DEFINE_BINARY_NOT_SUPPORTED(gt)
DEFINE_BINARY_NOT_SUPPORTED(lt)
DEFINE_BINARY_NOT_SUPPORTED(ge)
DEFINE_BINARY_NOT_SUPPORTED(sym_min)
DEFINE_BINARY_NOT_SUPPORTED(sym_max)
DEFINE_BINARY_NOT_SUPPORTED(sym_and)
DEFINE_BINARY_NOT_SUPPORTED(sym_or)
#undef DEFINE_BINARY_NOT_SUPPORTED
#define DEFINE_NOT_SUPPORTED(name) \
c10::SymNode name() override { \
TORCH_CHECK(false, #name " is not supported by SingletonSymNode"); \
}
DEFINE_NOT_SUPPORTED(sym_not)
DEFINE_NOT_SUPPORTED(ceil)
DEFINE_NOT_SUPPORTED(floor)
DEFINE_NOT_SUPPORTED(neg)
DEFINE_NOT_SUPPORTED(clone)
DEFINE_NOT_SUPPORTED(sym_float)
#undef DEFINE_NOT_SUPPORTED
private:
int64_t val_;
};
} // namespace c10

View File

@ -1,3 +1,4 @@
#include <c10/core/ConstantSymNodeImpl.h>
#include <c10/core/SymBool.h>
#include <c10/core/SymNodeImpl.h>
#include <array>
@ -6,87 +7,76 @@
namespace c10 {
SymNode SymBool::toSymNodeImpl() const {
TORCH_CHECK(is_symbolic());
TORCH_CHECK(is_heap_allocated());
return SymNode::reclaim_copy(toSymNodeImplUnowned());
}
SymNode SymBool::wrap_node(const SymNode& base) const {
if (is_symbolic()) {
return toSymNodeImpl();
if (auto ma = maybe_as_bool()) {
return base->wrap_bool(*ma);
} else {
return base->wrap_bool(as_bool_unchecked());
return toSymNodeImpl();
}
}
static std::array<SymNode, 2> normalize_symbools(
const SymBool& a_,
const SymBool& b_) {
SymNode a, b;
if (a_.is_symbolic())
a = a_.toSymNodeImpl();
if (b_.is_symbolic())
b = b_.toSymNodeImpl();
#define DEFINE_BINARY(API, OP, METHOD, RET) \
RET SymBool::API(const SymBool& sci) const { \
if (auto ma = maybe_as_bool()) { \
if (auto mb = sci.maybe_as_bool()) { \
return RET(OP(*ma, *mb)); \
} else { \
auto b = sci.toSymNodeImpl(); \
return RET(b->wrap_bool(*ma)->METHOD(b)); \
} \
} else { \
if (auto mb = sci.maybe_as_bool()) { \
auto a = toSymNodeImplUnowned(); \
return RET(a->METHOD(a->wrap_bool(*mb))); \
} else { \
return RET(toSymNodeImplUnowned()->METHOD(sci.toSymNodeImpl())); \
} \
} \
}
SymNodeImpl* common = a ? a.get() : b.get();
if (!a) {
a = common->wrap_bool(a_.as_bool_unchecked());
}
if (!b) {
b = common->wrap_bool(b_.as_bool_unchecked());
}
return {std::move(a), std::move(b)};
}
SymBool SymBool::sym_and(const SymBool& sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymBool(data_ && sci.data_);
}
auto res = normalize_symbools(*this, sci);
return SymBool(res[0]->sym_and(res[1]));
}
SymBool SymBool::sym_or(const SymBool& sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymBool(data_ || sci.data_);
}
auto res = normalize_symbools(*this, sci);
return SymBool(res[0]->sym_or(res[1]));
}
// clang-format off
DEFINE_BINARY(sym_and, std::logical_and<>(), sym_and, SymBool)
DEFINE_BINARY(sym_or, std::logical_or<>(), sym_or, SymBool)
// clang-format on
SymBool SymBool::sym_not() const {
if (!is_symbolic()) {
return SymBool(!data_);
if (auto ma = maybe_as_bool()) {
return SymBool(!*ma);
}
return SymBool(toSymNodeImpl()->sym_not());
}
std::ostream& operator<<(std::ostream& os, const SymBool& s) {
if (s.is_symbolic()) {
os << s.toSymNodeImpl()->str();
if (auto ma = s.maybe_as_bool()) {
os << *ma;
} else {
os << s.as_bool_unchecked();
os << s.toSymNodeImpl()->str();
}
return os;
}
bool SymBool::guard_bool(const char* file, int64_t line) const {
if (!is_symbolic()) {
return data_;
if (auto ma = maybe_as_bool()) {
return *ma;
}
SymNode a = toSymNodeImpl();
return a->guard_bool(file, line);
}
bool SymBool::expect_true(const char* file, int64_t line) const {
if (!is_symbolic()) {
return data_;
if (auto ma = maybe_as_bool()) {
return *ma;
}
SymNode a = toSymNodeImpl();
return a->expect_true(file, line);
}
bool SymBool::has_hint() const {
if (!is_symbolic()) {
if (auto ma = maybe_as_bool()) {
return true;
}
return toSymNodeImpl()->has_hint();

View File

@ -23,15 +23,16 @@ class C10_API SymBool {
return std::move(ptr_).release();
}
// Only valid if is_symbolic()
// Only valid if is_heap_allocated()
SymNode toSymNodeImpl() const;
// Guaranteed to return a SymNode, wrapping using base if necessary
SymNode wrap_node(const SymNode& base) const;
bool expect_bool() const {
TORCH_CHECK(!is_symbolic());
return data_;
c10::optional<bool> c = maybe_as_bool();
TORCH_CHECK(c.has_value());
return *c;
}
SymBool sym_and(const SymBool&) const;
@ -56,14 +57,21 @@ class C10_API SymBool {
bool has_hint() const;
C10_ALWAYS_INLINE bool is_symbolic() const {
return ptr_;
}
bool as_bool_unchecked() const {
return data_;
}
c10::optional<bool> maybe_as_bool() const {
if (!is_heap_allocated()) {
return c10::make_optional(data_);
}
return toSymNodeImplUnowned()->constant_bool();
}
bool is_heap_allocated() const {
return ptr_;
}
private:
// TODO: optimize to union
bool data_;

View File

@ -1,4 +1,4 @@
#include <c10/core/LargeNegativeIntSymNodeImpl.h>
#include <c10/core/ConstantSymNodeImpl.h>
#include <c10/core/SymFloat.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
@ -13,7 +13,7 @@ namespace c10 {
// Postcondition: invariants on SymInt are fixed
void SymInt::promote_to_negative() {
auto s =
SymInt(SymNode(c10::make_intrusive<LargeNegativeIntSymNodeImpl>(data_)));
SymInt(SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(data_)));
// Similar to move operator=, but do NOT release data_
data_ = s.data_;
s.data_ = 0;

View File

@ -145,10 +145,10 @@ class C10_API SymInt {
// number can be used to diagnose overspecialization.
int64_t guard_int(const char* file, int64_t line) const;
// Distinguish actual symbolic values from large negative integers.
// Distinguish actual symbolic values from constants stored on the heap
bool is_symbolic() const {
return is_heap_allocated() &&
toSymNodeImplUnowned()->large_negative_int() == 0;
!toSymNodeImplUnowned()->constant_int().has_value();
}
// N.B. It's important to keep this definition in the header
@ -215,15 +215,10 @@ class C10_API SymInt {
return c10::make_optional(data_);
}
auto* node = toSymNodeImplUnowned();
int64_t c = node->large_negative_int();
if (c != 0) {
return c10::make_optional(c);
if (auto c = node->constant_int()) {
return c;
}
c10::optional<int64_t> d = node->maybe_as_int();
if (d.has_value()) {
return d;
}
return c10::nullopt;
return node->maybe_as_int();
}
// Return whether the integer is directly coercible to a SymInt

View File

@ -169,8 +169,14 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
virtual std::string str() {
TORCH_CHECK(false, "NYI");
};
virtual int64_t large_negative_int() {
return 0; // not a large negative int!
virtual c10::optional<int64_t> singleton_int() {
return c10::nullopt;
}
virtual c10::optional<int64_t> constant_int() {
return c10::nullopt;
}
virtual c10::optional<bool> constant_bool() {
return c10::nullopt;
}
virtual c10::optional<int64_t> maybe_as_int() {
return c10::nullopt;
@ -178,7 +184,7 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
std::ostream& operator<<(std::ostream& os) {
os << str();
return os;
};
}
};
} // namespace c10

View File

@ -1,5 +1,6 @@
#include <gtest/gtest.h>
#include <c10/core/SingletonSymNodeImpl.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
@ -21,4 +22,32 @@ TEST(SymIntTest, ConcreteInts) {
TEST(SymIntTest, CheckRange) {
EXPECT_FALSE(SymInt::check_range(INT64_MIN));
}
TEST(SymIntTest, SingletonSymNode) {
auto a = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1)));
auto b = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1)));
auto c = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2)));
auto d = c10::SymInt(3);
ASSERT_TRUE(a == a);
ASSERT_TRUE(a == b);
ASSERT_FALSE(a != a);
ASSERT_FALSE(a != b);
ASSERT_FALSE(a == c);
ASSERT_TRUE(a != c);
// Tentaively throw an error when comparing with a non-singleton, this is not
// necessarily the right behavior.
ASSERT_THROW((void)(a == d), c10::Error);
ASSERT_THROW((void)(a != d), c10::Error);
ASSERT_THROW((void)(d == a), c10::Error);
ASSERT_THROW((void)(d != a), c10::Error);
ASSERT_THROW((void)(a >= b), c10::Error); // "not supported by..."
ASSERT_THROW((void)(a >= d), c10::Error); // "not supported by..."
ASSERT_THROW((void)(d >= a), c10::Error); // "NYI"
}
#endif

View File

@ -97,7 +97,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
// no op, there is nothing to tag
break;
case c10::SymBoolType::Kind:
TORCH_CHECK(!w.value.toSymBool().is_symbolic());
TORCH_CHECK(!w.value.toSymBool().is_heap_allocated());
// no op, there is nothing to tag
break;
case DynamicType::Kind:

View File

@ -99,14 +99,14 @@ py::handle type_caster<c10::SymBool>::cast(
const c10::SymBool& si,
return_value_policy /* policy */,
handle /* parent */) {
if (si.is_symbolic()) {
if (auto m = si.maybe_as_bool()) {
return py::cast(*m).release();
} else {
// TODO: generalize this to work with C++ backed class
auto* py_node =
dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
TORCH_INTERNAL_ASSERT(py_node);
return torch::get_symbool_class()(py_node->getPyObj()).release();
} else {
return py::cast(si.as_bool_unchecked()).release();
}
}

View File

@ -14,6 +14,7 @@
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <c10/core/SingletonSymNodeImpl.h>
#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
@ -738,6 +739,11 @@ void initDispatchBindings(PyObject* module) {
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
});
m.def("_get_singleton_int", [](int64_t data) {
return c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(data)));
});
}
// TODO: dedupe with the kernel