Add is_nested_int() (#119975)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119975
Approved by: https://github.com/jbschlosser
ghstack dependencies: #119661, #119974
This commit is contained in:
soulitzer
2024-02-19 15:49:02 -05:00
committed by PyTorch MergeBot
parent 2e77629b9f
commit 27c5bbe5cb
9 changed files with 34 additions and 22 deletions

View File

@ -6,7 +6,7 @@ namespace c10 {
namespace {
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
TORCH_INTERNAL_ASSERT(lhs->nested_int().has_value());
TORCH_INTERNAL_ASSERT(lhs->is_nested_int());
c10::optional<int64_t> c = rhs->nested_int();
return (
c.has_value() && lhs->nested_int() == *c &&

View File

@ -57,6 +57,10 @@ class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
return false;
}
bool is_nested_int() const override {
return true;
}
bool has_hint() override {
return true;
}

View File

@ -3,15 +3,15 @@
namespace c10 {
// This is used to support the case where the lhs is a constant symnode
// and the rhs is a singleton symnode. This situation occurs today when we
// and the rhs is a nested int symnode. This situation occurs today when we
// perform a binary op between nested int and plain int and the
// singleton promotes the int into a constant symnode. If we'd like to
// int is promoted into a constant symnode. If we'd like to
// support more combinations in the future, we may need to implement some
// kind of multiple dispatch.
#define DEFINE_BINARY_OP(OP, ROP) \
template <typename T> \
c10::SymNode ConstantSymNodeImpl<T>::OP(const c10::SymNode& other) { \
TORCH_INTERNAL_ASSERT(other->nested_int().has_value()); \
TORCH_INTERNAL_ASSERT(other->is_nested_int()); \
return other->ROP( \
c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim_copy(this)); \
}

View File

@ -37,6 +37,9 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
virtual bool is_float() {
TORCH_CHECK(false, "NYI");
};
virtual bool is_nested_int() const {
return false;
};
virtual SymNode add(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};

View File

@ -301,9 +301,8 @@ class SymInt:
return str(self.node)
def __hash__(self) -> builtins.int:
ret = self.node.nested_int()
if ret is not None:
return hash(ret)
if self.node.is_nested_int():
return hash(self.node.nested_int())
else:
# We could support constant SymInts as well, but not doing it for now
raise TypeError("unhashable type: non-nested SymInt")

View File

@ -1284,6 +1284,11 @@ void initJITBindings(PyObject* module) {
[](const c10::SymNode& node){
return node->is_constant();
})
.def(
"is_nested_int",
[](const c10::SymNode& node) {
return node->is_nested_int();
})
.def(
"is_symbolic",
[](const c10::SymNode& node) {

View File

@ -95,6 +95,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
return getPyObj().attr("is_bool")().is(py::handle(Py_True));
}
bool is_nested_int() const override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("is_nested_int")().is(py::handle(Py_True));
}
bool has_hint() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("has_hint")().is(py::handle(Py_True));
@ -273,8 +278,8 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
return dispatch_common_(__func__);
}
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
py::handle getPyObj() const {
return py::handle(pyobj_->ptr(getPyInterpreter()));
}
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};

View File

@ -173,6 +173,14 @@ class SymNode:
def is_bool(self):
return self.pytype is bool
def is_nested_int(self):
# Unbacked SymInts cannot be nested int today
return (
self._hint is not None
and isinstance(self._hint, SymInt)
and self._hint.node.is_nested_int()
)
def wrap_int(self, num):
assert type(num) is int
import sympy

View File

@ -263,19 +263,7 @@ def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
return False
def is_nested_int(s):
# check for NestedIntSymNode
if not isinstance(s, torch.SymInt):
return False
if s.node.nested_int() is not None:
return True
# check for symbolic variable wrapping a NestedIntSymNode (fake-ifying causes this)
return (
s.node.is_symbolic()
and s.node.hint is not None
and isinstance(s.node.hint, torch.SymInt)
and s.node.hint.node.nested_int() is not None
)
return isinstance(s, torch.SymInt) and s.node.is_nested_int()
def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
if isinstance(val, SymTypes):