mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Support SingletonSymNode mul with coefficient (#110369)"
This reverts commit eb8feb8ff8610d53d92773c2d7dce05c2196d672. Reverted https://github.com/pytorch/pytorch/pull/110369 on behalf of https://github.com/PaliC due to bottom diff is causing a plethora of internal failures ([comment](https://github.com/pytorch/pytorch/pull/110369#issuecomment-1749802899))
This commit is contained in:
@ -22,7 +22,6 @@ DEFINE_BINARY_OP(ge, le)
|
||||
DEFINE_BINARY_OP(le, ge)
|
||||
DEFINE_BINARY_OP(lt, gt)
|
||||
DEFINE_BINARY_OP(gt, lt)
|
||||
DEFINE_BINARY_OP(mul, mul)
|
||||
|
||||
#undef DEFINE_BINARY_OP
|
||||
|
||||
|
||||
@ -54,7 +54,6 @@ class C10_API ConstantSymNodeImpl : public SymNodeImpl {
|
||||
c10::SymNode le(const c10::SymNode& other) override;
|
||||
c10::SymNode lt(const c10::SymNode& other) override;
|
||||
c10::SymNode gt(const c10::SymNode& other) override;
|
||||
c10::SymNode mul(const c10::SymNode& other) override;
|
||||
std::string str() override {
|
||||
if constexpr (is_int_()) {
|
||||
return std::to_string(std::get<int64_t>(value_));
|
||||
|
||||
@ -8,15 +8,13 @@ namespace {
|
||||
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
||||
TORCH_INTERNAL_ASSERT(lhs->singleton_int().has_value());
|
||||
c10::optional<int64_t> c = rhs->singleton_int();
|
||||
return (
|
||||
c.has_value() && lhs->singleton_int() == *c &&
|
||||
lhs->singleton_coeff() == rhs->singleton_coeff());
|
||||
return c.has_value() && lhs->singleton_int() == *c;
|
||||
}
|
||||
bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
||||
if (auto mb_si = lhs->singleton_int()) {
|
||||
if (auto mb_si2 = rhs->singleton_int()) {
|
||||
if (*mb_si == *mb_si2) {
|
||||
return lhs->singleton_coeff() >= rhs->singleton_coeff();
|
||||
return true;
|
||||
}
|
||||
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
|
||||
}
|
||||
@ -64,13 +62,4 @@ c10::SymNode SingletonSymNodeImpl::le(const c10::SymNode& other) {
|
||||
_ge("le", other.get(), this)));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::mul(const c10::SymNode& other) {
|
||||
if (auto mb_si = other->singleton_int()) {
|
||||
TORCH_CHECK(false, "Singleton int cannot be multiplied by singleton int");
|
||||
}
|
||||
c10::optional<int64_t> c = other->constant_int();
|
||||
TORCH_CHECK(c.has_value());
|
||||
return SymNode(c10::make_intrusive<SingletonSymNodeImpl>(val_, coeff_ * *c));
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
@ -20,20 +20,11 @@ namespace c10 {
|
||||
// a proxy to evaluate equality. We also constrain the range of values for this
|
||||
// as to enable inequality checks.
|
||||
//
|
||||
// We also support a positive integer scalar "coeff" that is used for computing
|
||||
// strides. For example given, a [B, j0, D] tensor, it can be strided in two
|
||||
// different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to
|
||||
// differentiate the two cases.
|
||||
//
|
||||
// During tracing the strides of the outputs need to be a function of the size
|
||||
// and strides of the inputs so it is important that SingletonSymNode itself is
|
||||
// able to express this.
|
||||
class C10_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
public:
|
||||
// CAUTION: you should probably not be constructing these directly; please
|
||||
// the higher-level API in python instead (TODO: actually introduce that).
|
||||
explicit SingletonSymNodeImpl(int64_t val, int64_t coeff)
|
||||
: val_(val), coeff_(coeff) {}
|
||||
explicit SingletonSymNodeImpl(int64_t val) : val_(val) {}
|
||||
|
||||
bool bool_() override {
|
||||
return false;
|
||||
@ -76,10 +67,7 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
}
|
||||
|
||||
std::string str() override {
|
||||
if (coeff_ == 1) {
|
||||
return "j" + std::to_string(val_);
|
||||
}
|
||||
return std::to_string(coeff_) + "*j" + std::to_string(val_);
|
||||
return "j" + std::to_string(val_);
|
||||
}
|
||||
|
||||
// NOTE [ Inequalities with SingletonInt ]
|
||||
@ -108,30 +96,17 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
// would mean that means that if we define the indeterminate j0 >= 3 to be
|
||||
// False, the also indeterminate j0 < 3 will be evaluated to be True!
|
||||
//
|
||||
// [ Coefficient are assumed positive ]
|
||||
//
|
||||
// For the purpose of computing inequalities, we consider the coefficient of
|
||||
// the SingletonInt to be a positive integer.
|
||||
//
|
||||
// Thus, no modificaitons are needed to the logic since
|
||||
// j0 >= k implies coeff * j0 >= k
|
||||
//
|
||||
c10::SymNode eq(const c10::SymNode& other) override;
|
||||
c10::SymNode ne(const c10::SymNode& other) override;
|
||||
c10::SymNode ge(const c10::SymNode& other) override;
|
||||
c10::SymNode gt(const c10::SymNode& other) override;
|
||||
c10::SymNode lt(const c10::SymNode& other) override;
|
||||
c10::SymNode le(const c10::SymNode& other) override;
|
||||
c10::SymNode mul(const c10::SymNode& other) override;
|
||||
|
||||
c10::optional<int64_t> singleton_int() override {
|
||||
return val_;
|
||||
}
|
||||
|
||||
c10::optional<int64_t> singleton_coeff() override {
|
||||
return coeff_;
|
||||
}
|
||||
|
||||
bool is_symbolic() override {
|
||||
return false;
|
||||
}
|
||||
@ -143,6 +118,7 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
|
||||
DEFINE_BINARY_NOT_SUPPORTED(add)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(sub)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(mul)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(truediv)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(pow)
|
||||
DEFINE_BINARY_NOT_SUPPORTED(floordiv)
|
||||
@ -170,7 +146,6 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
|
||||
private:
|
||||
int64_t val_;
|
||||
int64_t coeff_;
|
||||
};
|
||||
|
||||
} // namespace c10
|
||||
|
||||
@ -177,9 +177,6 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
||||
virtual c10::optional<int64_t> singleton_int() {
|
||||
return c10::nullopt;
|
||||
}
|
||||
virtual c10::optional<int64_t> singleton_coeff() {
|
||||
return c10::nullopt;
|
||||
}
|
||||
virtual c10::optional<int64_t> constant_int() {
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
@ -25,11 +25,11 @@ TEST(SymIntTest, CheckRange) {
|
||||
|
||||
TEST(SymIntTest, SingletonSymNode) {
|
||||
auto a = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1)));
|
||||
auto b = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1)));
|
||||
auto c = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2, 1)));
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2)));
|
||||
auto d = c10::SymInt(3);
|
||||
|
||||
ASSERT_TRUE(a == a);
|
||||
@ -100,23 +100,4 @@ TEST(SymIntTest, SingletonSymNode) {
|
||||
EXPECT_THROW((void)(a > 2), c10::Error);
|
||||
ASSERT_TRUE(a > 1);
|
||||
}
|
||||
|
||||
TEST(SymIntTest, SingletonSymNodeWithFactor) {
|
||||
auto a = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 5)));
|
||||
auto b = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 10)));
|
||||
// eq
|
||||
ASSERT_FALSE(a == b);
|
||||
ASSERT_FALSE(a >= b);
|
||||
ASSERT_TRUE(b >= a);
|
||||
ASSERT_TRUE(a <= b);
|
||||
ASSERT_FALSE(b <= a);
|
||||
// ne
|
||||
ASSERT_TRUE(a != b);
|
||||
// mul
|
||||
ASSERT_TRUE(a * 2 == b);
|
||||
ASSERT_TRUE(a * 3 >= b);
|
||||
ASSERT_TRUE(a * 2 == 2 * a);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -751,9 +751,9 @@ class TestSymNumberMagicMethods(TestCase):
|
||||
hash(x)
|
||||
|
||||
# Singleton SymInt, constant SymBool, SymNode are hashable
|
||||
j1 = torch._C._get_singleton_int(1, 1)
|
||||
j1_copy = torch._C._get_singleton_int(1, 1)
|
||||
j2 = torch._C._get_singleton_int(2, 1)
|
||||
j1 = torch._C._get_singleton_int(1)
|
||||
j1_copy = torch._C._get_singleton_int(1)
|
||||
j2 = torch._C._get_singleton_int(2)
|
||||
t = self.get_constant_bool(True)
|
||||
t_copy = self.get_constant_bool(True)
|
||||
f = self.get_constant_bool(False)
|
||||
@ -773,9 +773,9 @@ class TestSymNumberMagicMethods(TestCase):
|
||||
hash(m)
|
||||
|
||||
def test_non_symbolic_symnode(self):
|
||||
j1 = torch._C._get_singleton_int(1, 1)
|
||||
j2 = torch._C._get_singleton_int(1, 1)
|
||||
j3 = torch._C._get_singleton_int(3, 1)
|
||||
j1 = torch._C._get_singleton_int(1)
|
||||
j2 = torch._C._get_singleton_int(1)
|
||||
j3 = torch._C._get_singleton_int(3)
|
||||
|
||||
self.assertIsInstance(j1, torch.SymInt)
|
||||
self.assertNotIsInstance(j1, int)
|
||||
|
||||
@ -2851,7 +2851,7 @@ class TestNestedTensorSubclass(TestCase):
|
||||
"directly calling torch.ops.aten.size"):
|
||||
torch.ops.aten.size.default(nt)
|
||||
|
||||
singleton_int = torch.nested._internal.nested_tensor.get_tensor_id(_offsets, coeff=1)
|
||||
singleton_int = torch.nested._internal.nested_tensor.get_tensor_id(_offsets)
|
||||
self.assertEqual(nt.size(), (3, singleton_int, 3))
|
||||
self.assertEqual(nt.shape, (3, singleton_int, 3))
|
||||
self.assertEqual(nt.dim(), 3)
|
||||
|
||||
@ -1514,7 +1514,7 @@ def _are_functorch_transforms_active() -> _bool: ...
|
||||
# Define in torch/csrc/autograd/init.cpp
|
||||
def _set_python_dispatcher(dispatcher: object) -> None: ...
|
||||
|
||||
def _get_singleton_int(id: _int, coeff: _int) -> SymInt: ...
|
||||
def _get_singleton_int(id: _int) -> SymInt: ...
|
||||
|
||||
def _get_constant_bool_symnode(val: _bool) -> Any: ...
|
||||
|
||||
|
||||
@ -1283,11 +1283,6 @@ void initJITBindings(PyObject* module) {
|
||||
"singleton_int",
|
||||
[](const c10::SymNode& node) {
|
||||
return node->singleton_int();
|
||||
})
|
||||
.def(
|
||||
"singleton_coeff",
|
||||
[](const c10::SymNode& node) {
|
||||
return node->singleton_coeff();
|
||||
});
|
||||
|
||||
// clang-format on
|
||||
|
||||
@ -779,9 +779,9 @@ void initDispatchBindings(PyObject* module) {
|
||||
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
|
||||
});
|
||||
|
||||
m.def("_get_singleton_int", [](int64_t data, int64_t coeff) {
|
||||
return c10::SymInt(c10::SymNode(
|
||||
c10::make_intrusive<c10::SingletonSymNodeImpl>(data, coeff)));
|
||||
m.def("_get_singleton_int", [](int64_t data) {
|
||||
return c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(data)));
|
||||
});
|
||||
|
||||
m.def("_get_constant_bool_symnode", [](int64_t data) {
|
||||
|
||||
@ -9,12 +9,12 @@ _tensor_id_counter = 0
|
||||
_tensor_id_registry = WeakTensorKeyDictionary()
|
||||
|
||||
|
||||
def get_tensor_id(tensor, *, coeff=1):
|
||||
def get_tensor_id(tensor):
|
||||
global _tensor_id_counter
|
||||
if tensor not in _tensor_id_registry:
|
||||
_tensor_id_registry[tensor] = _tensor_id_counter
|
||||
_tensor_id_counter += 1
|
||||
return torch._C._get_singleton_int(_tensor_id_registry[tensor], coeff)
|
||||
return torch._C._get_singleton_int(_tensor_id_registry[tensor])
|
||||
|
||||
|
||||
class NestedTensor(torch.Tensor):
|
||||
@ -61,7 +61,7 @@ class NestedTensor(torch.Tensor):
|
||||
|
||||
# In a later PR, we'll need to accept an additional size argument
|
||||
# to handle dynamic shapes.
|
||||
ragged_dim = get_tensor_id(offsets, coeff=1)
|
||||
ragged_dim = get_tensor_id(offsets)
|
||||
D = values.shape[1]
|
||||
B = offsets.shape[0] - 1
|
||||
self._size = (B, ragged_dim, D)
|
||||
|
||||
Reference in New Issue
Block a user