mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	[reland] Update singleton int to error when inequality relation is undefined (#110672)
reland of https://github.com/pytorch/pytorch/pull/110044 Pull Request resolved: https://github.com/pytorch/pytorch/pull/110672 Approved by: https://github.com/ezyang
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							576b80d23e
						
					
				
				
					commit
					69ea214cc2
				
			@ -1,3 +1,65 @@
 | 
			
		||||
#include <c10/core/SingletonSymNodeImpl.h>
 | 
			
		||||
#include <c10/core/SymNodeImpl.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
 | 
			
		||||
namespace c10 {} // namespace c10
 | 
			
		||||
namespace c10 {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
 | 
			
		||||
  TORCH_INTERNAL_ASSERT(lhs->singleton_int().has_value());
 | 
			
		||||
  c10::optional<int64_t> c = rhs->singleton_int();
 | 
			
		||||
  return c.has_value() && lhs->singleton_int() == *c;
 | 
			
		||||
}
 | 
			
		||||
bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
 | 
			
		||||
  if (auto mb_si = lhs->singleton_int()) {
 | 
			
		||||
    if (auto mb_si2 = rhs->singleton_int()) {
 | 
			
		||||
      if (*mb_si == *mb_si2) {
 | 
			
		||||
        return true;
 | 
			
		||||
      }
 | 
			
		||||
      TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
 | 
			
		||||
    }
 | 
			
		||||
    if (rhs->constant_int() && *rhs->constant_int() <= 2) {
 | 
			
		||||
      return true;
 | 
			
		||||
    }
 | 
			
		||||
    TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
 | 
			
		||||
  } else if (rhs->singleton_int()) {
 | 
			
		||||
    if (lhs->constant_int() && *lhs->constant_int() < 2) {
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
    TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_INTERNAL_ASSERT(false, "expect at least one singleton");
 | 
			
		||||
}
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
c10::SymNode SingletonSymNodeImpl::eq(const c10::SymNode& other) {
 | 
			
		||||
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
 | 
			
		||||
      _eq("eq", this, other.get())));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
c10::SymNode SingletonSymNodeImpl::ne(const c10::SymNode& other) {
 | 
			
		||||
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
 | 
			
		||||
      !_eq("ne", this, other.get())));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
c10::SymNode SingletonSymNodeImpl::ge(const c10::SymNode& other) {
 | 
			
		||||
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
 | 
			
		||||
      _ge("ge", this, other.get())));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
c10::SymNode SingletonSymNodeImpl::gt(const c10::SymNode& other) {
 | 
			
		||||
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
 | 
			
		||||
      !_ge("gt", other.get(), this)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
c10::SymNode SingletonSymNodeImpl::lt(const c10::SymNode& other) {
 | 
			
		||||
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
 | 
			
		||||
      !_ge("lt", this, other.get())));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
c10::SymNode SingletonSymNodeImpl::le(const c10::SymNode& other) {
 | 
			
		||||
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
 | 
			
		||||
      _ge("le", other.get(), this)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace c10
 | 
			
		||||
 | 
			
		||||
@ -19,6 +19,7 @@ namespace c10 {
 | 
			
		||||
// we associate each raggedness pattern with an integer "id" that can be used as
 | 
			
		||||
// a proxy to evaluate equality. We also constrain the range of values for this
 | 
			
		||||
// as to enable inequality checks.
 | 
			
		||||
//
 | 
			
		||||
class C10_API SingletonSymNodeImpl : public SymNodeImpl {
 | 
			
		||||
 public:
 | 
			
		||||
  // CAUTION: you should probably not be constructing these directly; please
 | 
			
		||||
@ -69,55 +70,38 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
 | 
			
		||||
    return "j" + std::to_string(val_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  c10::SymNode eq(const c10::SymNode& other) override {
 | 
			
		||||
    c10::optional<int64_t> c = other->singleton_int();
 | 
			
		||||
    bool ret = c.has_value() && val_ == *c;
 | 
			
		||||
    return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(ret));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  c10::SymNode ne(const c10::SymNode& other) override {
 | 
			
		||||
    c10::optional<int64_t> c = other->singleton_int();
 | 
			
		||||
    bool ret = !c.has_value() || val_ != *c;
 | 
			
		||||
    return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(ret));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // It would be cool to have the ability to arbitrarily constrain the range of
 | 
			
		||||
  // values as we do for unbacked symints. For now a useful default
 | 
			
		||||
  // range seems to be [2, int64_t::max()] (1) since sizes are non-negative, and
 | 
			
		||||
  // (2) we need to get past 0/1 specialization checks.
 | 
			
		||||
  c10::SymNode ge(const c10::SymNode& other) override {
 | 
			
		||||
    if (auto mb_si = other->singleton_int()) {
 | 
			
		||||
      return SymNode(
 | 
			
		||||
          c10::make_intrusive<ConstantSymNodeImpl<bool>>(val_ == *mb_si));
 | 
			
		||||
    }
 | 
			
		||||
    c10::optional<int64_t> c = other->constant_int();
 | 
			
		||||
    TORCH_CHECK(c.has_value());
 | 
			
		||||
    return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(*c <= 2));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  c10::SymNode gt(const c10::SymNode& other) override {
 | 
			
		||||
    if (auto mb_si = other->singleton_int()) {
 | 
			
		||||
      return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(false));
 | 
			
		||||
    }
 | 
			
		||||
    c10::optional<int64_t> c = other->constant_int();
 | 
			
		||||
    TORCH_CHECK(c.has_value());
 | 
			
		||||
    return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(*c < 2));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  c10::SymNode lt(const c10::SymNode& other) override {
 | 
			
		||||
    return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(false));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  c10::SymNode le(const c10::SymNode& other) override {
 | 
			
		||||
    if (auto mb_si = other->singleton_int()) {
 | 
			
		||||
      return SymNode(
 | 
			
		||||
          c10::make_intrusive<ConstantSymNodeImpl<bool>>(val_ == *mb_si));
 | 
			
		||||
    }
 | 
			
		||||
    c10::optional<int64_t> c = other->constant_int();
 | 
			
		||||
    TORCH_CHECK(c.has_value());
 | 
			
		||||
    return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
 | 
			
		||||
        *c >= std::numeric_limits<int64_t>::max()));
 | 
			
		||||
  }
 | 
			
		||||
  // NOTE [ Inequalities with SingletonInt ]
 | 
			
		||||
  //
 | 
			
		||||
  // The semantics of SingletonInt when it comes to relations is that it is
 | 
			
		||||
  // treated as integer known to be within a certain range,
 | 
			
		||||
  //
 | 
			
		||||
  //     j0 \in [2, int64_t::max]
 | 
			
		||||
  //
 | 
			
		||||
  // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
 | 
			
		||||
  // This is a useful default range for the raggedness pattern of a jagged
 | 
			
		||||
  // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
 | 
			
		||||
  // specialization checks.
 | 
			
		||||
  //
 | 
			
		||||
  // [ Indeterminate inequalities error out ]
 | 
			
		||||
  //
 | 
			
		||||
  // Given the semantic defined above, certain relations like j0 < 3 are thus
 | 
			
		||||
  // indeterminable. In our impl today, evaluating such relations error
 | 
			
		||||
  //
 | 
			
		||||
  // It may seem convenient to just define indeterminate relations to return
 | 
			
		||||
  // False, but the implementation we maintain in parallel using sympy does not
 | 
			
		||||
  // allow this.
 | 
			
		||||
  //
 | 
			
		||||
  // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
 | 
			
		||||
  // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
 | 
			
		||||
  // would mean that means that if we define the indeterminate j0 >= 3 to be
 | 
			
		||||
  // False, the also indeterminate j0 < 3 will be evaluated to be True!
 | 
			
		||||
  //
 | 
			
		||||
  c10::SymNode eq(const c10::SymNode& other) override;
 | 
			
		||||
  c10::SymNode ne(const c10::SymNode& other) override;
 | 
			
		||||
  c10::SymNode ge(const c10::SymNode& other) override;
 | 
			
		||||
  c10::SymNode gt(const c10::SymNode& other) override;
 | 
			
		||||
  c10::SymNode lt(const c10::SymNode& other) override;
 | 
			
		||||
  c10::SymNode le(const c10::SymNode& other) override;
 | 
			
		||||
 | 
			
		||||
  c10::optional<int64_t> singleton_int() override {
 | 
			
		||||
    return val_;
 | 
			
		||||
 | 
			
		||||
@ -48,48 +48,56 @@ TEST(SymIntTest, SingletonSymNode) {
 | 
			
		||||
  ASSERT_TRUE(a >= a);
 | 
			
		||||
  ASSERT_TRUE(a >= b);
 | 
			
		||||
  ASSERT_TRUE(b >= a);
 | 
			
		||||
  ASSERT_FALSE(a >= c);
 | 
			
		||||
  ASSERT_FALSE(c >= a);
 | 
			
		||||
  ASSERT_FALSE(c >= 3);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(a >= c), c10::Error);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(c >= a), c10::Error);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(c >= 3), c10::Error);
 | 
			
		||||
  ASSERT_TRUE(c >= 2);
 | 
			
		||||
  ASSERT_TRUE(c >= 1);
 | 
			
		||||
  ASSERT_TRUE(std::numeric_limits<int64_t>::max() >= c);
 | 
			
		||||
  ASSERT_FALSE(std::numeric_limits<int64_t>::max() - 1 >= c);
 | 
			
		||||
  ASSERT_FALSE(1 >= c);
 | 
			
		||||
 | 
			
		||||
  // lt
 | 
			
		||||
  ASSERT_FALSE(a < a);
 | 
			
		||||
  ASSERT_FALSE(a < b);
 | 
			
		||||
  ASSERT_FALSE(b < a);
 | 
			
		||||
  ASSERT_FALSE(a < c);
 | 
			
		||||
  ASSERT_FALSE(c < a);
 | 
			
		||||
  ASSERT_FALSE(a < std::numeric_limits<int64_t>::max());
 | 
			
		||||
  ASSERT_FALSE(3 < a);
 | 
			
		||||
  ASSERT_FALSE(2 < a);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(a < c), c10::Error);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(c < a), c10::Error);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(3 < a), c10::Error);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(2 < a), c10::Error);
 | 
			
		||||
  ASSERT_TRUE(1 < a);
 | 
			
		||||
 | 
			
		||||
  // le
 | 
			
		||||
  ASSERT_TRUE(a <= a);
 | 
			
		||||
  ASSERT_TRUE(b <= a);
 | 
			
		||||
  ASSERT_TRUE(a <= b);
 | 
			
		||||
  ASSERT_FALSE(c <= a);
 | 
			
		||||
  ASSERT_FALSE(a <= c);
 | 
			
		||||
  ASSERT_FALSE(3 <= c);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(a <= c), c10::Error);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(c <= a), c10::Error);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(3 <= c), c10::Error);
 | 
			
		||||
  ASSERT_TRUE(2 <= c);
 | 
			
		||||
  ASSERT_TRUE(1 <= c);
 | 
			
		||||
  ASSERT_TRUE(c <= std::numeric_limits<int64_t>::max());
 | 
			
		||||
  ASSERT_FALSE(c <= std::numeric_limits<int64_t>::max() - 1);
 | 
			
		||||
  ASSERT_FALSE(c <= 1);
 | 
			
		||||
 | 
			
		||||
  // gt
 | 
			
		||||
  ASSERT_FALSE(a > a);
 | 
			
		||||
  ASSERT_FALSE(b > a);
 | 
			
		||||
  ASSERT_FALSE(a > b);
 | 
			
		||||
  ASSERT_FALSE(c > a);
 | 
			
		||||
  ASSERT_FALSE(a > c);
 | 
			
		||||
  ASSERT_FALSE(std::numeric_limits<int64_t>::max() > a);
 | 
			
		||||
  ASSERT_FALSE(a > 3);
 | 
			
		||||
  ASSERT_FALSE(a > 2);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(a > c), c10::Error);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(c > a), c10::Error);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(a > 3), c10::Error);
 | 
			
		||||
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 | 
			
		||||
  EXPECT_THROW((void)(a > 2), c10::Error);
 | 
			
		||||
  ASSERT_TRUE(a > 1);
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
@ -784,7 +784,8 @@ class TestSymNumberMagicMethods(TestCase):
 | 
			
		||||
            j1 + 3
 | 
			
		||||
 | 
			
		||||
        self.assertFalse(j1 == 3)
 | 
			
		||||
        self.assertFalse(3 >= j2)
 | 
			
		||||
        with self.assertRaisesRegex(RuntimeError, "indeterminate"):
 | 
			
		||||
            self.assertFalse(3 >= j2)
 | 
			
		||||
 | 
			
		||||
        self.assertIs(j1 == j1, True)
 | 
			
		||||
        self.assertIs(j1 == j2, True)
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user