[dynamo][guards][serialization] Dont use ID_MATCH guard for bool and None (#149228)

Doing this removes the need of collecting `id` and therefore facilitates serialization. It also improves readability with recompilations. Earlier, recompile message will just show the `id`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149228
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain
2025-03-14 14:12:21 -07:00
committed by PyTorch MergeBot
parent 186cc7327c
commit f9a787224c
2 changed files with 97 additions and 1 deletions

View File

@ -1464,6 +1464,35 @@ class GuardBuilder(GuardBuilderBase):
not invert, key, get_verbose_code_parts(code, guard)
)
def BOOL_MATCH(self, guard: Guard):
# checks val == True or val == False
ref = self.arg_ref(guard)
val = self.get(guard.name)
assert istype(val, bool)
code = [f"{ref} == {val!r}"]
self._set_guard_export_info(guard, code)
if val:
self.get_guard_manager(guard).add_true_match_guard(
get_verbose_code_parts(code, guard)
)
else:
self.get_guard_manager(guard).add_false_match_guard(
get_verbose_code_parts(code, guard)
)
def NONE_MATCH(self, guard: Guard):
# checks `val is None`
ref = self.arg_ref(guard)
val = self.get(guard.name)
assert val is None
code = [f"{ref} is None"]
self._set_guard_export_info(guard, code)
self.get_guard_manager(guard).add_none_match_guard(
get_verbose_code_parts(code, guard)
)
def ID_MATCH(self, guard: Guard):
# ___check_obj_id is same as `id(x) == y`
if isinstance(guard.originating_source, TypeSource):
@ -1682,7 +1711,11 @@ class GuardBuilder(GuardBuilderBase):
def CONSTANT_MATCH(self, guard: Guard):
val = self.get(guard.name)
if istype(val, (bool, type(None), types.CodeType)):
if istype(val, bool):
self.BOOL_MATCH(guard)
elif val is None:
self.NONE_MATCH(guard)
elif istype(val, types.CodeType):
self.ID_MATCH(guard)
else:
self.EQUALS_MATCH(guard)

View File

@ -1486,6 +1486,36 @@ class ID_MATCH : public LeafGuard {
intptr_t _expected;
};
class NONE_MATCH : public LeafGuard {
public:
NONE_MATCH(py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
return value == Py_None;
}
};
class TRUE_MATCH : public LeafGuard {
public:
TRUE_MATCH(py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
return value == Py_True;
}
};
class FALSE_MATCH : public LeafGuard {
public:
FALSE_MATCH(py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
return value == Py_False;
}
};
class EQUALS_MATCH : public LeafGuard {
public:
EQUALS_MATCH(py::object value, py::object verbose_code_parts)
@ -5233,6 +5263,18 @@ PyObject* torch_c_dynamo_guards_init() {
py::class_<ID_MATCH, LeafGuard, std::shared_ptr<ID_MATCH>>(py_m, "ID_MATCH")
.def(py::init<py::object, py::list>())
.def("__call__", &ID_MATCH::check);
py::class_<NONE_MATCH, LeafGuard, std::shared_ptr<NONE_MATCH>>(
py_m, "NONE_MATCH")
.def(py::init<py::list>())
.def("__call__", &NONE_MATCH::check);
py::class_<TRUE_MATCH, LeafGuard, std::shared_ptr<TRUE_MATCH>>(
py_m, "TRUE_MATCH")
.def(py::init<py::list>())
.def("__call__", &TRUE_MATCH::check);
py::class_<FALSE_MATCH, LeafGuard, std::shared_ptr<FALSE_MATCH>>(
py_m, "FALSE_MATCH")
.def(py::init<py::list>())
.def("__call__", &FALSE_MATCH::check);
py::class_<EQUALS_MATCH, LeafGuard, std::shared_ptr<EQUALS_MATCH>>(
py_m, "EQUALS_MATCH")
.def(py::init<py::object, py::list>())
@ -5478,6 +5520,27 @@ PyObject* torch_c_dynamo_guards_init() {
self.add_leaf_guard(std::make_shared<ID_MATCH>(
std::move(value), std::move(verbose_code_parts)));
})
.def(
"add_none_match_guard",
[](GuardManager& self, py::object verbose_code_parts) -> void {
SKIP_IF_GUARD_ALREADY_PRESENT("NONE_MATCH");
self.add_leaf_guard(
std::make_shared<NONE_MATCH>(std::move(verbose_code_parts)));
})
.def(
"add_true_match_guard",
[](GuardManager& self, py::object verbose_code_parts) -> void {
SKIP_IF_GUARD_ALREADY_PRESENT("TRUE_MATCH");
self.add_leaf_guard(
std::make_shared<TRUE_MATCH>(std::move(verbose_code_parts)));
})
.def(
"add_false_match_guard",
[](GuardManager& self, py::object verbose_code_parts) -> void {
SKIP_IF_GUARD_ALREADY_PRESENT("FALSE_MATCH");
self.add_leaf_guard(
std::make_shared<FALSE_MATCH>(std::move(verbose_code_parts)));
})
.def(
"add_equals_match_guard",
[](GuardManager& self,