Revert "Support SingletonSymNode mul with coefficient (#110369)"

This reverts commit eb8feb8ff8610d53d92773c2d7dce05c2196d672.

Reverted https://github.com/pytorch/pytorch/pull/110369 on behalf of https://github.com/PaliC due to bottom diff is causing a plethora of internal failures ([comment](https://github.com/pytorch/pytorch/pull/110369#issuecomment-1749802899))
This commit is contained in:
PyTorch MergeBot
2023-10-05 23:51:26 +00:00
parent 236afe73a2
commit 1c3fae46ee
12 changed files with 22 additions and 87 deletions

View File

@ -22,7 +22,6 @@ DEFINE_BINARY_OP(ge, le)
DEFINE_BINARY_OP(le, ge)
DEFINE_BINARY_OP(lt, gt)
DEFINE_BINARY_OP(gt, lt)
DEFINE_BINARY_OP(mul, mul)
#undef DEFINE_BINARY_OP

View File

@ -54,7 +54,6 @@ class C10_API ConstantSymNodeImpl : public SymNodeImpl {
c10::SymNode le(const c10::SymNode& other) override;
c10::SymNode lt(const c10::SymNode& other) override;
c10::SymNode gt(const c10::SymNode& other) override;
c10::SymNode mul(const c10::SymNode& other) override;
std::string str() override {
if constexpr (is_int_()) {
return std::to_string(std::get<int64_t>(value_));

View File

@ -8,15 +8,13 @@ namespace {
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
TORCH_INTERNAL_ASSERT(lhs->singleton_int().has_value());
c10::optional<int64_t> c = rhs->singleton_int();
return (
c.has_value() && lhs->singleton_int() == *c &&
lhs->singleton_coeff() == rhs->singleton_coeff());
return c.has_value() && lhs->singleton_int() == *c;
}
bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
if (auto mb_si = lhs->singleton_int()) {
if (auto mb_si2 = rhs->singleton_int()) {
if (*mb_si == *mb_si2) {
return lhs->singleton_coeff() >= rhs->singleton_coeff();
return true;
}
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
}
@ -64,13 +62,4 @@ c10::SymNode SingletonSymNodeImpl::le(const c10::SymNode& other) {
_ge("le", other.get(), this)));
}
c10::SymNode SingletonSymNodeImpl::mul(const c10::SymNode& other) {
if (auto mb_si = other->singleton_int()) {
TORCH_CHECK(false, "Singleton int cannot be multiplied by singleton int");
}
c10::optional<int64_t> c = other->constant_int();
TORCH_CHECK(c.has_value());
return SymNode(c10::make_intrusive<SingletonSymNodeImpl>(val_, coeff_ * *c));
}
} // namespace c10

View File

@ -20,20 +20,11 @@ namespace c10 {
// a proxy to evaluate equality. We also constrain the range of values for this
// as to enable inequality checks.
//
// We also support a positive integer scalar "coeff" that is used for computing
// strides. For example given, a [B, j0, D] tensor, it can be strided in two
// different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to
// differentiate the two cases.
//
// During tracing the strides of the outputs need to be a function of the size
// and strides of the inputs so it is important that SingletonSymNode itself is
// able to express this.
class C10_API SingletonSymNodeImpl : public SymNodeImpl {
public:
// CAUTION: you should probably not be constructing these directly; please
// use the higher-level API in python instead (TODO: actually introduce that).
explicit SingletonSymNodeImpl(int64_t val, int64_t coeff)
: val_(val), coeff_(coeff) {}
explicit SingletonSymNodeImpl(int64_t val) : val_(val) {}
bool bool_() override {
return false;
@ -76,10 +67,7 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
}
std::string str() override {
if (coeff_ == 1) {
return "j" + std::to_string(val_);
}
return std::to_string(coeff_) + "*j" + std::to_string(val_);
return "j" + std::to_string(val_);
}
// NOTE [ Inequalities with SingletonInt ]
@ -108,30 +96,17 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
// would mean that if we define the indeterminate j0 >= 3 to be
// False, the also-indeterminate j0 < 3 will be evaluated to be True!
//
// [ Coefficient are assumed positive ]
//
// For the purpose of computing inequalities, we consider the coefficient of
// the SingletonInt to be a positive integer.
//
// Thus, no modifications are needed to the logic since
// j0 >= k implies coeff * j0 >= k
//
c10::SymNode eq(const c10::SymNode& other) override;
c10::SymNode ne(const c10::SymNode& other) override;
c10::SymNode ge(const c10::SymNode& other) override;
c10::SymNode gt(const c10::SymNode& other) override;
c10::SymNode lt(const c10::SymNode& other) override;
c10::SymNode le(const c10::SymNode& other) override;
c10::SymNode mul(const c10::SymNode& other) override;
c10::optional<int64_t> singleton_int() override {
return val_;
}
c10::optional<int64_t> singleton_coeff() override {
return coeff_;
}
bool is_symbolic() override {
return false;
}
@ -143,6 +118,7 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
DEFINE_BINARY_NOT_SUPPORTED(add)
DEFINE_BINARY_NOT_SUPPORTED(sub)
DEFINE_BINARY_NOT_SUPPORTED(mul)
DEFINE_BINARY_NOT_SUPPORTED(truediv)
DEFINE_BINARY_NOT_SUPPORTED(pow)
DEFINE_BINARY_NOT_SUPPORTED(floordiv)
@ -170,7 +146,6 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
private:
int64_t val_;
int64_t coeff_;
};
} // namespace c10

View File

@ -177,9 +177,6 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
virtual c10::optional<int64_t> singleton_int() {
return c10::nullopt;
}
virtual c10::optional<int64_t> singleton_coeff() {
return c10::nullopt;
}
virtual c10::optional<int64_t> constant_int() {
return c10::nullopt;
}

View File

@ -25,11 +25,11 @@ TEST(SymIntTest, CheckRange) {
TEST(SymIntTest, SingletonSymNode) {
auto a = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1)));
auto b = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1)));
auto c = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2, 1)));
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2)));
auto d = c10::SymInt(3);
ASSERT_TRUE(a == a);
@ -100,23 +100,4 @@ TEST(SymIntTest, SingletonSymNode) {
EXPECT_THROW((void)(a > 2), c10::Error);
ASSERT_TRUE(a > 1);
}
TEST(SymIntTest, SingletonSymNodeWithFactor) {
auto a = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 5)));
auto b = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 10)));
// eq
ASSERT_FALSE(a == b);
ASSERT_FALSE(a >= b);
ASSERT_TRUE(b >= a);
ASSERT_TRUE(a <= b);
ASSERT_FALSE(b <= a);
// ne
ASSERT_TRUE(a != b);
// mul
ASSERT_TRUE(a * 2 == b);
ASSERT_TRUE(a * 3 >= b);
ASSERT_TRUE(a * 2 == 2 * a);
}
#endif

View File

@ -751,9 +751,9 @@ class TestSymNumberMagicMethods(TestCase):
hash(x)
# Singleton SymInt, constant SymBool, SymNode are hashable
j1 = torch._C._get_singleton_int(1, 1)
j1_copy = torch._C._get_singleton_int(1, 1)
j2 = torch._C._get_singleton_int(2, 1)
j1 = torch._C._get_singleton_int(1)
j1_copy = torch._C._get_singleton_int(1)
j2 = torch._C._get_singleton_int(2)
t = self.get_constant_bool(True)
t_copy = self.get_constant_bool(True)
f = self.get_constant_bool(False)
@ -773,9 +773,9 @@ class TestSymNumberMagicMethods(TestCase):
hash(m)
def test_non_symbolic_symnode(self):
j1 = torch._C._get_singleton_int(1, 1)
j2 = torch._C._get_singleton_int(1, 1)
j3 = torch._C._get_singleton_int(3, 1)
j1 = torch._C._get_singleton_int(1)
j2 = torch._C._get_singleton_int(1)
j3 = torch._C._get_singleton_int(3)
self.assertIsInstance(j1, torch.SymInt)
self.assertNotIsInstance(j1, int)

View File

@ -2851,7 +2851,7 @@ class TestNestedTensorSubclass(TestCase):
"directly calling torch.ops.aten.size"):
torch.ops.aten.size.default(nt)
singleton_int = torch.nested._internal.nested_tensor.get_tensor_id(_offsets, coeff=1)
singleton_int = torch.nested._internal.nested_tensor.get_tensor_id(_offsets)
self.assertEqual(nt.size(), (3, singleton_int, 3))
self.assertEqual(nt.shape, (3, singleton_int, 3))
self.assertEqual(nt.dim(), 3)

View File

@ -1514,7 +1514,7 @@ def _are_functorch_transforms_active() -> _bool: ...
# Define in torch/csrc/autograd/init.cpp
def _set_python_dispatcher(dispatcher: object) -> None: ...
def _get_singleton_int(id: _int, coeff: _int) -> SymInt: ...
def _get_singleton_int(id: _int) -> SymInt: ...
def _get_constant_bool_symnode(val: _bool) -> Any: ...

View File

@ -1283,11 +1283,6 @@ void initJITBindings(PyObject* module) {
"singleton_int",
[](const c10::SymNode& node) {
return node->singleton_int();
})
.def(
"singleton_coeff",
[](const c10::SymNode& node) {
return node->singleton_coeff();
});
// clang-format on

View File

@ -779,9 +779,9 @@ void initDispatchBindings(PyObject* module) {
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
});
m.def("_get_singleton_int", [](int64_t data, int64_t coeff) {
return c10::SymInt(c10::SymNode(
c10::make_intrusive<c10::SingletonSymNodeImpl>(data, coeff)));
m.def("_get_singleton_int", [](int64_t data) {
return c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(data)));
});
m.def("_get_constant_bool_symnode", [](int64_t data) {

View File

@ -9,12 +9,12 @@ _tensor_id_counter = 0
_tensor_id_registry = WeakTensorKeyDictionary()
def get_tensor_id(tensor, *, coeff=1):
def get_tensor_id(tensor):
global _tensor_id_counter
if tensor not in _tensor_id_registry:
_tensor_id_registry[tensor] = _tensor_id_counter
_tensor_id_counter += 1
return torch._C._get_singleton_int(_tensor_id_registry[tensor], coeff)
return torch._C._get_singleton_int(_tensor_id_registry[tensor])
class NestedTensor(torch.Tensor):
@ -61,7 +61,7 @@ class NestedTensor(torch.Tensor):
# In a later PR, we'll need to accept an additional size argument
# to handle dynamic shapes.
ragged_dim = get_tensor_id(offsets, coeff=1)
ragged_dim = get_tensor_id(offsets)
D = values.shape[1]
B = offsets.shape[0] - 1
self._size = (B, ragged_dim, D)