mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[reland] Update singleton int to error when inequality relation is undefined (#110672)
reland of https://github.com/pytorch/pytorch/pull/110044 Pull Request resolved: https://github.com/pytorch/pytorch/pull/110672 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
576b80d23e
commit
69ea214cc2
@ -1,3 +1,65 @@
|
||||
#include <c10/core/SingletonSymNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace c10 {} // namespace c10
|
||||
namespace c10 {
|
||||
|
||||
namespace {
|
||||
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
||||
TORCH_INTERNAL_ASSERT(lhs->singleton_int().has_value());
|
||||
c10::optional<int64_t> c = rhs->singleton_int();
|
||||
return c.has_value() && lhs->singleton_int() == *c;
|
||||
}
|
||||
bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
||||
if (auto mb_si = lhs->singleton_int()) {
|
||||
if (auto mb_si2 = rhs->singleton_int()) {
|
||||
if (*mb_si == *mb_si2) {
|
||||
return true;
|
||||
}
|
||||
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
|
||||
}
|
||||
if (rhs->constant_int() && *rhs->constant_int() <= 2) {
|
||||
return true;
|
||||
}
|
||||
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
|
||||
} else if (rhs->singleton_int()) {
|
||||
if (lhs->constant_int() && *lhs->constant_int() < 2) {
|
||||
return false;
|
||||
}
|
||||
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(false, "expect at least one singleton");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::eq(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
_eq("eq", this, other.get())));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::ne(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
!_eq("ne", this, other.get())));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::ge(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
_ge("ge", this, other.get())));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::gt(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
!_ge("gt", other.get(), this)));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::lt(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
!_ge("lt", this, other.get())));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::le(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
_ge("le", other.get(), this)));
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
@ -19,6 +19,7 @@ namespace c10 {
|
||||
// we associate each raggedness pattern with an integer "id" that can be used as
|
||||
// a proxy to evaluate equality. We also constrain the range of values for this
|
||||
// as to enable inequality checks.
|
||||
//
|
||||
class C10_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
public:
|
||||
// CAUTION: you should probably not be constructing these directly; please
|
||||
@ -69,55 +70,38 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
return "j" + std::to_string(val_);
|
||||
}
|
||||
|
||||
c10::SymNode eq(const c10::SymNode& other) override {
|
||||
c10::optional<int64_t> c = other->singleton_int();
|
||||
bool ret = c.has_value() && val_ == *c;
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(ret));
|
||||
}
|
||||
|
||||
c10::SymNode ne(const c10::SymNode& other) override {
|
||||
c10::optional<int64_t> c = other->singleton_int();
|
||||
bool ret = !c.has_value() || val_ != *c;
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(ret));
|
||||
}
|
||||
|
||||
// It would be cool to have the ability to arbitrarily constrain the range of
|
||||
// values as we do for unbacked symints. For now a useful default
|
||||
// range seems to be [2, int64_t::max()] (1) since sizes are non-negative, and
|
||||
// (2) we need to get past 0/1 specialization checks.
|
||||
c10::SymNode ge(const c10::SymNode& other) override {
|
||||
if (auto mb_si = other->singleton_int()) {
|
||||
return SymNode(
|
||||
c10::make_intrusive<ConstantSymNodeImpl<bool>>(val_ == *mb_si));
|
||||
}
|
||||
c10::optional<int64_t> c = other->constant_int();
|
||||
TORCH_CHECK(c.has_value());
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(*c <= 2));
|
||||
}
|
||||
|
||||
c10::SymNode gt(const c10::SymNode& other) override {
|
||||
if (auto mb_si = other->singleton_int()) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(false));
|
||||
}
|
||||
c10::optional<int64_t> c = other->constant_int();
|
||||
TORCH_CHECK(c.has_value());
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(*c < 2));
|
||||
}
|
||||
|
||||
c10::SymNode lt(const c10::SymNode& other) override {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(false));
|
||||
}
|
||||
|
||||
c10::SymNode le(const c10::SymNode& other) override {
|
||||
if (auto mb_si = other->singleton_int()) {
|
||||
return SymNode(
|
||||
c10::make_intrusive<ConstantSymNodeImpl<bool>>(val_ == *mb_si));
|
||||
}
|
||||
c10::optional<int64_t> c = other->constant_int();
|
||||
TORCH_CHECK(c.has_value());
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
*c >= std::numeric_limits<int64_t>::max()));
|
||||
}
|
||||
// NOTE [ Inequalities with SingletonInt ]
|
||||
//
|
||||
// The semantics of SingletonInt when it comes to relations is that it is
|
||||
// treated as integer known to be within a certain range,
|
||||
//
|
||||
// j0 \in [2, int64_t::max]
|
||||
//
|
||||
// allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
|
||||
// This is a useful default range for the raggedness pattern of a jagged
|
||||
// tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
|
||||
// specialization checks.
|
||||
//
|
||||
// [ Indeterminate inequalities error out ]
|
||||
//
|
||||
// Given the semantic defined above, certain relations like j0 < 3 are thus
|
||||
// indeterminable. In our impl today, evaluating such relations error
|
||||
//
|
||||
// It may seem convenient to just define indeterminate relations to return
|
||||
// False, but the implementation we maintain in parallel using sympy does not
|
||||
// allow this.
|
||||
//
|
||||
// Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
|
||||
// by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
|
||||
// would mean that means that if we define the indeterminate j0 >= 3 to be
|
||||
// False, the also indeterminate j0 < 3 will be evaluated to be True!
|
||||
//
|
||||
c10::SymNode eq(const c10::SymNode& other) override;
|
||||
c10::SymNode ne(const c10::SymNode& other) override;
|
||||
c10::SymNode ge(const c10::SymNode& other) override;
|
||||
c10::SymNode gt(const c10::SymNode& other) override;
|
||||
c10::SymNode lt(const c10::SymNode& other) override;
|
||||
c10::SymNode le(const c10::SymNode& other) override;
|
||||
|
||||
c10::optional<int64_t> singleton_int() override {
|
||||
return val_;
|
||||
|
@ -48,48 +48,56 @@ TEST(SymIntTest, SingletonSymNode) {
|
||||
ASSERT_TRUE(a >= a);
|
||||
ASSERT_TRUE(a >= b);
|
||||
ASSERT_TRUE(b >= a);
|
||||
ASSERT_FALSE(a >= c);
|
||||
ASSERT_FALSE(c >= a);
|
||||
ASSERT_FALSE(c >= 3);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(a >= c), c10::Error);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(c >= a), c10::Error);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(c >= 3), c10::Error);
|
||||
ASSERT_TRUE(c >= 2);
|
||||
ASSERT_TRUE(c >= 1);
|
||||
ASSERT_TRUE(std::numeric_limits<int64_t>::max() >= c);
|
||||
ASSERT_FALSE(std::numeric_limits<int64_t>::max() - 1 >= c);
|
||||
ASSERT_FALSE(1 >= c);
|
||||
|
||||
// lt
|
||||
ASSERT_FALSE(a < a);
|
||||
ASSERT_FALSE(a < b);
|
||||
ASSERT_FALSE(b < a);
|
||||
ASSERT_FALSE(a < c);
|
||||
ASSERT_FALSE(c < a);
|
||||
ASSERT_FALSE(a < std::numeric_limits<int64_t>::max());
|
||||
ASSERT_FALSE(3 < a);
|
||||
ASSERT_FALSE(2 < a);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(a < c), c10::Error);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(c < a), c10::Error);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(3 < a), c10::Error);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(2 < a), c10::Error);
|
||||
ASSERT_TRUE(1 < a);
|
||||
|
||||
// le
|
||||
ASSERT_TRUE(a <= a);
|
||||
ASSERT_TRUE(b <= a);
|
||||
ASSERT_TRUE(a <= b);
|
||||
ASSERT_FALSE(c <= a);
|
||||
ASSERT_FALSE(a <= c);
|
||||
ASSERT_FALSE(3 <= c);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(a <= c), c10::Error);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(c <= a), c10::Error);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(3 <= c), c10::Error);
|
||||
ASSERT_TRUE(2 <= c);
|
||||
ASSERT_TRUE(1 <= c);
|
||||
ASSERT_TRUE(c <= std::numeric_limits<int64_t>::max());
|
||||
ASSERT_FALSE(c <= std::numeric_limits<int64_t>::max() - 1);
|
||||
ASSERT_FALSE(c <= 1);
|
||||
|
||||
// gt
|
||||
ASSERT_FALSE(a > a);
|
||||
ASSERT_FALSE(b > a);
|
||||
ASSERT_FALSE(a > b);
|
||||
ASSERT_FALSE(c > a);
|
||||
ASSERT_FALSE(a > c);
|
||||
ASSERT_FALSE(std::numeric_limits<int64_t>::max() > a);
|
||||
ASSERT_FALSE(a > 3);
|
||||
ASSERT_FALSE(a > 2);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(a > c), c10::Error);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(c > a), c10::Error);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(a > 3), c10::Error);
|
||||
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
||||
EXPECT_THROW((void)(a > 2), c10::Error);
|
||||
ASSERT_TRUE(a > 1);
|
||||
}
|
||||
#endif
|
||||
|
@ -784,7 +784,8 @@ class TestSymNumberMagicMethods(TestCase):
|
||||
j1 + 3
|
||||
|
||||
self.assertFalse(j1 == 3)
|
||||
self.assertFalse(3 >= j2)
|
||||
with self.assertRaisesRegex(RuntimeError, "indeterminate"):
|
||||
self.assertFalse(3 >= j2)
|
||||
|
||||
self.assertIs(j1 == j1, True)
|
||||
self.assertIs(j1 == j2, True)
|
||||
|
Reference in New Issue
Block a user