[inductor] Add FLOAT_IS_NAN and COMPLEX_IS_NAN guards (#162537)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162537
Approved by: https://github.com/anijain2305, https://github.com/mlazos
ghstack dependencies: #162528
This commit is contained in:
Isuru Fernando
2025-09-11 18:36:36 +00:00
committed by PyTorch MergeBot
parent 5dd84559a5
commit 79d2418b5a
3 changed files with 65 additions and 10 deletions

View File

@ -325,6 +325,14 @@ class GuardManager:
level: int,
verbose_code_parts: list[str],
) -> None: ...
def add_float_is_nan_guard(
self,
verbose_code_parts: list[str],
) -> None: ...
def add_complex_is_nan_guard(
self,
verbose_code_parts: list[str],
) -> None: ...
def add_tuple_iterator_length_guard(
self,
length: int,

View File

@ -2197,26 +2197,20 @@ class GuardBuilder(GuardBuilderBase):
# Special case for nan because float("nan") == float("nan") evaluates to False
if istype(val, float) and math.isnan(val):
self.TYPE_MATCH(guard)
code = []
code.append(f"__math_isnan({ref})")
code = [f"(type({ref}) is float and __math_isnan({ref}))"]
self._set_guard_export_info(guard, code)
self.get_guard_manager(guard).add_lambda_guard_no_framelocals(
_get_closure_vars()["__math_isnan"], # type: ignore[arg-type]
self.get_guard_manager(guard).add_float_is_nan_guard(
get_verbose_code_parts(code, guard),
)
return
# Python math library doesn't support complex nan, so we need to use numpy
if istype(val, complex) and np.isnan(val):
self.TYPE_MATCH(guard)
code = []
code.append(f"__numpy_isnan({ref})")
code = [f"(type({ref}) is complex and __numpy_isnan({ref}))"]
self._set_guard_export_info(guard, code)
self.get_guard_manager(guard).add_lambda_guard_no_framelocals(
_get_closure_vars()["__numpy_isnan"], # type: ignore[arg-type]
self.get_guard_manager(guard).add_complex_is_nan_guard(
get_verbose_code_parts(code, guard),
)
return

View File

@ -2254,6 +2254,39 @@ class SET_CONTAINS : public LeafGuard {
py::object _item;
};
// Check if the float is nan
class FLOAT_IS_NAN : public LeafGuard {
public:
FLOAT_IS_NAN(
RootGuardManager* root_guard_manager,
py::object verbose_code_parts)
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
if (!PyFloat_CheckExact(value)) {
return false;
}
return std::isnan(PyFloat_AsDouble(value));
}
};
// Check if the float is nan
class COMPLEX_IS_NAN : public LeafGuard {
public:
COMPLEX_IS_NAN(
RootGuardManager* root_guard_manager,
py::object verbose_code_parts)
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
if (!PyComplex_CheckExact(value)) {
return false;
}
Py_complex c_value = PyComplex_AsCComplex(value);
return std::isnan(c_value.real) || std::isnan(c_value.imag);
}
};
// Check if the dual level is the same as the one in fx graph
class DUAL_LEVEL_MATCH : public LeafGuard {
public:
@ -6875,6 +6908,14 @@ PyObject* torch_c_dynamo_guards_init() {
py_m, "DUAL_LEVEL_MATCH")
.def(py::init<RootGuardManager*, int64_t, py::list>())
.def("__call__", &DUAL_LEVEL_MATCH::check);
py::class_<FLOAT_IS_NAN, LeafGuard, std::shared_ptr<FLOAT_IS_NAN>>(
py_m, "FLOAT_IS_NAN")
.def(py::init<RootGuardManager*, py::list>())
.def("__call__", &FLOAT_IS_NAN::check);
py::class_<COMPLEX_IS_NAN, LeafGuard, std::shared_ptr<COMPLEX_IS_NAN>>(
py_m, "COMPLEX_IS_NAN")
.def(py::init<RootGuardManager*, py::list>())
.def("__call__", &COMPLEX_IS_NAN::check);
py::class_<DYNAMIC_INDICES, LeafGuard, std::shared_ptr<DYNAMIC_INDICES>>(
py_m, "DYNAMIC_INDICES")
.def(py::init<RootGuardManager*, py::set, py::list>())
@ -7316,6 +7357,18 @@ PyObject* torch_c_dynamo_guards_init() {
self.add_leaf_guard(std::make_shared<DUAL_LEVEL_MATCH>(
self.get_root(), level, std::move(verbose_code_parts)));
})
.def(
"add_float_is_nan_guard",
[](GuardManager& self, py::object verbose_code_parts) -> void {
self.add_leaf_guard(std::make_shared<FLOAT_IS_NAN>(
self.get_root(), std::move(verbose_code_parts)));
})
.def(
"add_complex_is_nan_guard",
[](GuardManager& self, py::object verbose_code_parts) -> void {
self.add_leaf_guard(std::make_shared<COMPLEX_IS_NAN>(
self.get_root(), std::move(verbose_code_parts)));
})
.def(
"add_dynamic_indices_guard",
[](GuardManager& self,