[reland] Update singleton int to error when inequality relation is undefined (#110672)

reland of https://github.com/pytorch/pytorch/pull/110044
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110672
Approved by: https://github.com/ezyang
This commit is contained in:
soulitzer
2023-10-05 21:12:14 -04:00
committed by PyTorch MergeBot
parent 576b80d23e
commit 69ea214cc2
4 changed files with 126 additions and 71 deletions

View File

@ -1,3 +1,65 @@
#include <c10/core/SingletonSymNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/util/Exception.h>
namespace c10 {} // namespace c10
namespace c10 {
namespace {
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
TORCH_INTERNAL_ASSERT(lhs->singleton_int().has_value());
c10::optional<int64_t> c = rhs->singleton_int();
return c.has_value() && lhs->singleton_int() == *c;
}
bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
if (auto mb_si = lhs->singleton_int()) {
if (auto mb_si2 = rhs->singleton_int()) {
if (*mb_si == *mb_si2) {
return true;
}
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
}
if (rhs->constant_int() && *rhs->constant_int() <= 2) {
return true;
}
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
} else if (rhs->singleton_int()) {
if (lhs->constant_int() && *lhs->constant_int() < 2) {
return false;
}
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
}
TORCH_INTERNAL_ASSERT(false, "expect at least one singleton");
}
} // namespace
c10::SymNode SingletonSymNodeImpl::eq(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
_eq("eq", this, other.get())));
}
c10::SymNode SingletonSymNodeImpl::ne(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
!_eq("ne", this, other.get())));
}
c10::SymNode SingletonSymNodeImpl::ge(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
_ge("ge", this, other.get())));
}
c10::SymNode SingletonSymNodeImpl::gt(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
!_ge("gt", other.get(), this)));
}
c10::SymNode SingletonSymNodeImpl::lt(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
!_ge("lt", this, other.get())));
}
c10::SymNode SingletonSymNodeImpl::le(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
_ge("le", other.get(), this)));
}
} // namespace c10

View File

@ -19,6 +19,7 @@ namespace c10 {
// we associate each raggedness pattern with an integer "id" that can be used as
// a proxy to evaluate equality. We also constrain the range of values for this
// as to enable inequality checks.
//
class C10_API SingletonSymNodeImpl : public SymNodeImpl {
public:
// CAUTION: you should probably not be constructing these directly; please
@ -69,55 +70,38 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
return "j" + std::to_string(val_);
}
c10::SymNode eq(const c10::SymNode& other) override {
c10::optional<int64_t> c = other->singleton_int();
bool ret = c.has_value() && val_ == *c;
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(ret));
}
c10::SymNode ne(const c10::SymNode& other) override {
c10::optional<int64_t> c = other->singleton_int();
bool ret = !c.has_value() || val_ != *c;
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(ret));
}
// It would be cool to have the ability to arbitrarily constrain the range of
// values as we do for unbacked symints. For now a useful default
// range seems to be [2, int64_t::max()] (1) since sizes are non-negative, and
// (2) we need to get past 0/1 specialization checks.
c10::SymNode ge(const c10::SymNode& other) override {
if (auto mb_si = other->singleton_int()) {
return SymNode(
c10::make_intrusive<ConstantSymNodeImpl<bool>>(val_ == *mb_si));
}
c10::optional<int64_t> c = other->constant_int();
TORCH_CHECK(c.has_value());
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(*c <= 2));
}
c10::SymNode gt(const c10::SymNode& other) override {
if (auto mb_si = other->singleton_int()) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(false));
}
c10::optional<int64_t> c = other->constant_int();
TORCH_CHECK(c.has_value());
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(*c < 2));
}
c10::SymNode lt(const c10::SymNode& other) override {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(false));
}
c10::SymNode le(const c10::SymNode& other) override {
if (auto mb_si = other->singleton_int()) {
return SymNode(
c10::make_intrusive<ConstantSymNodeImpl<bool>>(val_ == *mb_si));
}
c10::optional<int64_t> c = other->constant_int();
TORCH_CHECK(c.has_value());
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
*c >= std::numeric_limits<int64_t>::max()));
}
// NOTE [ Inequalities with SingletonInt ]
//
// The semantics of SingletonInt when it comes to relations is that it is
// treated as integer known to be within a certain range,
//
// j0 \in [2, int64_t::max]
//
// allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
// This is a useful default range for the raggedness pattern of a jagged
// tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
// specialization checks.
//
// [ Indeterminate inequalities error out ]
//
// Given the semantic defined above, certain relations like j0 < 3 are thus
// indeterminable. In our impl today, evaluating such relations error
//
// It may seem convenient to just define indeterminate relations to return
// False, but the implementation we maintain in parallel using sympy does not
// allow this.
//
// Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
// by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
// would mean that means that if we define the indeterminate j0 >= 3 to be
// False, the also indeterminate j0 < 3 will be evaluated to be True!
//
c10::SymNode eq(const c10::SymNode& other) override;
c10::SymNode ne(const c10::SymNode& other) override;
c10::SymNode ge(const c10::SymNode& other) override;
c10::SymNode gt(const c10::SymNode& other) override;
c10::SymNode lt(const c10::SymNode& other) override;
c10::SymNode le(const c10::SymNode& other) override;
c10::optional<int64_t> singleton_int() override {
return val_;

View File

@ -48,48 +48,56 @@ TEST(SymIntTest, SingletonSymNode) {
ASSERT_TRUE(a >= a);
ASSERT_TRUE(a >= b);
ASSERT_TRUE(b >= a);
ASSERT_FALSE(a >= c);
ASSERT_FALSE(c >= a);
ASSERT_FALSE(c >= 3);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a >= c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c >= a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c >= 3), c10::Error);
ASSERT_TRUE(c >= 2);
ASSERT_TRUE(c >= 1);
ASSERT_TRUE(std::numeric_limits<int64_t>::max() >= c);
ASSERT_FALSE(std::numeric_limits<int64_t>::max() - 1 >= c);
ASSERT_FALSE(1 >= c);
// lt
ASSERT_FALSE(a < a);
ASSERT_FALSE(a < b);
ASSERT_FALSE(b < a);
ASSERT_FALSE(a < c);
ASSERT_FALSE(c < a);
ASSERT_FALSE(a < std::numeric_limits<int64_t>::max());
ASSERT_FALSE(3 < a);
ASSERT_FALSE(2 < a);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a < c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c < a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(3 < a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(2 < a), c10::Error);
ASSERT_TRUE(1 < a);
// le
ASSERT_TRUE(a <= a);
ASSERT_TRUE(b <= a);
ASSERT_TRUE(a <= b);
ASSERT_FALSE(c <= a);
ASSERT_FALSE(a <= c);
ASSERT_FALSE(3 <= c);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a <= c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c <= a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(3 <= c), c10::Error);
ASSERT_TRUE(2 <= c);
ASSERT_TRUE(1 <= c);
ASSERT_TRUE(c <= std::numeric_limits<int64_t>::max());
ASSERT_FALSE(c <= std::numeric_limits<int64_t>::max() - 1);
ASSERT_FALSE(c <= 1);
// gt
ASSERT_FALSE(a > a);
ASSERT_FALSE(b > a);
ASSERT_FALSE(a > b);
ASSERT_FALSE(c > a);
ASSERT_FALSE(a > c);
ASSERT_FALSE(std::numeric_limits<int64_t>::max() > a);
ASSERT_FALSE(a > 3);
ASSERT_FALSE(a > 2);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a > c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c > a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a > 3), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a > 2), c10::Error);
ASSERT_TRUE(a > 1);
}
#endif

View File

@ -784,7 +784,8 @@ class TestSymNumberMagicMethods(TestCase):
j1 + 3
self.assertFalse(j1 == 3)
self.assertFalse(3 >= j2)
with self.assertRaisesRegex(RuntimeError, "indeterminate"):
self.assertFalse(3 >= j2)
self.assertIs(j1 == j1, True)
self.assertIs(j1 == j2, True)