mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add is_nested_int() (#119975)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119975 Approved by: https://github.com/jbschlosser ghstack dependencies: #119661, #119974
This commit is contained in:
committed by
PyTorch MergeBot
parent
2e77629b9f
commit
27c5bbe5cb
@ -6,7 +6,7 @@ namespace c10 {
|
||||
|
||||
namespace {
|
||||
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
||||
TORCH_INTERNAL_ASSERT(lhs->nested_int().has_value());
|
||||
TORCH_INTERNAL_ASSERT(lhs->is_nested_int());
|
||||
c10::optional<int64_t> c = rhs->nested_int();
|
||||
return (
|
||||
c.has_value() && lhs->nested_int() == *c &&
|
||||
|
@ -57,6 +57,10 @@ class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_nested_int() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool has_hint() override {
|
||||
return true;
|
||||
}
|
||||
|
@ -3,15 +3,15 @@
|
||||
namespace c10 {
|
||||
|
||||
// This is used to support the case where the lhs is a constant symnode
|
||||
// and the rhs is a singleton symnode. This situation occurs today when we
|
||||
// and the rhs is a nested int symnode. This situation occurs today when we
|
||||
// perform a binary op between nested int and plain int and the
|
||||
// singleton promotes the int into a constant symnode. If we'd like to
|
||||
// int is promoted into a constant symnode. If we'd like to
|
||||
// support more combinations in the future, we may need to implement some
|
||||
// kind of multiple dispatch.
|
||||
#define DEFINE_BINARY_OP(OP, ROP) \
|
||||
template <typename T> \
|
||||
c10::SymNode ConstantSymNodeImpl<T>::OP(const c10::SymNode& other) { \
|
||||
TORCH_INTERNAL_ASSERT(other->nested_int().has_value()); \
|
||||
TORCH_INTERNAL_ASSERT(other->is_nested_int()); \
|
||||
return other->ROP( \
|
||||
c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim_copy(this)); \
|
||||
}
|
||||
|
@ -37,6 +37,9 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
||||
virtual bool is_float() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual bool is_nested_int() const {
|
||||
return false;
|
||||
};
|
||||
virtual SymNode add(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
|
@ -301,9 +301,8 @@ class SymInt:
|
||||
return str(self.node)
|
||||
|
||||
def __hash__(self) -> builtins.int:
|
||||
ret = self.node.nested_int()
|
||||
if ret is not None:
|
||||
return hash(ret)
|
||||
if self.node.is_nested_int():
|
||||
return hash(self.node.nested_int())
|
||||
else:
|
||||
# We could support constant SymInts as well, but not doing it for now
|
||||
raise TypeError("unhashable type: non-nested SymInt")
|
||||
|
@ -1284,6 +1284,11 @@ void initJITBindings(PyObject* module) {
|
||||
[](const c10::SymNode& node){
|
||||
return node->is_constant();
|
||||
})
|
||||
.def(
|
||||
"is_nested_int",
|
||||
[](const c10::SymNode& node) {
|
||||
return node->is_nested_int();
|
||||
})
|
||||
.def(
|
||||
"is_symbolic",
|
||||
[](const c10::SymNode& node) {
|
||||
|
@ -95,6 +95,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
|
||||
return getPyObj().attr("is_bool")().is(py::handle(Py_True));
|
||||
}
|
||||
|
||||
bool is_nested_int() const override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("is_nested_int")().is(py::handle(Py_True));
|
||||
}
|
||||
|
||||
bool has_hint() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("has_hint")().is(py::handle(Py_True));
|
||||
@ -273,8 +278,8 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
|
||||
return dispatch_common_(__func__);
|
||||
}
|
||||
|
||||
py::handle getPyObj() {
|
||||
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
|
||||
py::handle getPyObj() const {
|
||||
return py::handle(pyobj_->ptr(getPyInterpreter()));
|
||||
}
|
||||
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
|
||||
};
|
||||
|
@ -173,6 +173,14 @@ class SymNode:
|
||||
def is_bool(self):
|
||||
return self.pytype is bool
|
||||
|
||||
def is_nested_int(self):
|
||||
# Unbacked SymInts cannot be nested int today
|
||||
return (
|
||||
self._hint is not None
|
||||
and isinstance(self._hint, SymInt)
|
||||
and self._hint.node.is_nested_int()
|
||||
)
|
||||
|
||||
def wrap_int(self, num):
|
||||
assert type(num) is int
|
||||
import sympy
|
||||
|
@ -263,19 +263,7 @@ def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
|
||||
return False
|
||||
|
||||
def is_nested_int(s):
|
||||
# check for NestedIntSymNode
|
||||
if not isinstance(s, torch.SymInt):
|
||||
return False
|
||||
if s.node.nested_int() is not None:
|
||||
return True
|
||||
|
||||
# check for symbolic variable wrapping a NestedIntSymNode (fake-ifying causes this)
|
||||
return (
|
||||
s.node.is_symbolic()
|
||||
and s.node.hint is not None
|
||||
and isinstance(s.node.hint, torch.SymInt)
|
||||
and s.node.hint.node.nested_int() is not None
|
||||
)
|
||||
return isinstance(s, torch.SymInt) and s.node.is_nested_int()
|
||||
|
||||
def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
|
||||
if isinstance(val, SymTypes):
|
||||
|
Reference in New Issue
Block a user