mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[inductor] Add FLOAT_IS_NAN and COMPLEX_IS_NAN guards (#162537)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162537 Approved by: https://github.com/anijain2305, https://github.com/mlazos ghstack dependencies: #162528
This commit is contained in:
committed by
PyTorch MergeBot
parent
5dd84559a5
commit
79d2418b5a
@ -325,6 +325,14 @@ class GuardManager:
|
||||
level: int,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_float_is_nan_guard(
|
||||
self,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_complex_is_nan_guard(
|
||||
self,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_tuple_iterator_length_guard(
|
||||
self,
|
||||
length: int,
|
||||
|
@ -2197,26 +2197,20 @@ class GuardBuilder(GuardBuilderBase):
|
||||
|
||||
# Special case for nan because float("nan") == float("nan") evaluates to False
|
||||
if istype(val, float) and math.isnan(val):
|
||||
self.TYPE_MATCH(guard)
|
||||
code = []
|
||||
code.append(f"__math_isnan({ref})")
|
||||
code = [f"(type({ref}) is float and __math_isnan({ref}))"]
|
||||
self._set_guard_export_info(guard, code)
|
||||
|
||||
self.get_guard_manager(guard).add_lambda_guard_no_framelocals(
|
||||
_get_closure_vars()["__math_isnan"], # type: ignore[arg-type]
|
||||
self.get_guard_manager(guard).add_float_is_nan_guard(
|
||||
get_verbose_code_parts(code, guard),
|
||||
)
|
||||
return
|
||||
|
||||
# Python math library doesn't support complex nan, so we need to use numpy
|
||||
if istype(val, complex) and np.isnan(val):
|
||||
self.TYPE_MATCH(guard)
|
||||
code = []
|
||||
code.append(f"__numpy_isnan({ref})")
|
||||
code = [f"(type({ref}) is complex and __numpy_isnan({ref}))"]
|
||||
self._set_guard_export_info(guard, code)
|
||||
|
||||
self.get_guard_manager(guard).add_lambda_guard_no_framelocals(
|
||||
_get_closure_vars()["__numpy_isnan"], # type: ignore[arg-type]
|
||||
self.get_guard_manager(guard).add_complex_is_nan_guard(
|
||||
get_verbose_code_parts(code, guard),
|
||||
)
|
||||
return
|
||||
|
@ -2254,6 +2254,39 @@ class SET_CONTAINS : public LeafGuard {
|
||||
py::object _item;
|
||||
};
|
||||
|
||||
// Check if the float is nan
|
||||
class FLOAT_IS_NAN : public LeafGuard {
|
||||
public:
|
||||
FLOAT_IS_NAN(
|
||||
RootGuardManager* root_guard_manager,
|
||||
py::object verbose_code_parts)
|
||||
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {}
|
||||
|
||||
bool check_nopybind(PyObject* value) override { // borrowed ref
|
||||
if (!PyFloat_CheckExact(value)) {
|
||||
return false;
|
||||
}
|
||||
return std::isnan(PyFloat_AsDouble(value));
|
||||
}
|
||||
};
|
||||
|
||||
// Check if the float is nan
|
||||
class COMPLEX_IS_NAN : public LeafGuard {
|
||||
public:
|
||||
COMPLEX_IS_NAN(
|
||||
RootGuardManager* root_guard_manager,
|
||||
py::object verbose_code_parts)
|
||||
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {}
|
||||
|
||||
bool check_nopybind(PyObject* value) override { // borrowed ref
|
||||
if (!PyComplex_CheckExact(value)) {
|
||||
return false;
|
||||
}
|
||||
Py_complex c_value = PyComplex_AsCComplex(value);
|
||||
return std::isnan(c_value.real) || std::isnan(c_value.imag);
|
||||
}
|
||||
};
|
||||
|
||||
// Check if the dual level is the same as the one in fx graph
|
||||
class DUAL_LEVEL_MATCH : public LeafGuard {
|
||||
public:
|
||||
@ -6875,6 +6908,14 @@ PyObject* torch_c_dynamo_guards_init() {
|
||||
py_m, "DUAL_LEVEL_MATCH")
|
||||
.def(py::init<RootGuardManager*, int64_t, py::list>())
|
||||
.def("__call__", &DUAL_LEVEL_MATCH::check);
|
||||
py::class_<FLOAT_IS_NAN, LeafGuard, std::shared_ptr<FLOAT_IS_NAN>>(
|
||||
py_m, "FLOAT_IS_NAN")
|
||||
.def(py::init<RootGuardManager*, py::list>())
|
||||
.def("__call__", &FLOAT_IS_NAN::check);
|
||||
py::class_<COMPLEX_IS_NAN, LeafGuard, std::shared_ptr<COMPLEX_IS_NAN>>(
|
||||
py_m, "COMPLEX_IS_NAN")
|
||||
.def(py::init<RootGuardManager*, py::list>())
|
||||
.def("__call__", &COMPLEX_IS_NAN::check);
|
||||
py::class_<DYNAMIC_INDICES, LeafGuard, std::shared_ptr<DYNAMIC_INDICES>>(
|
||||
py_m, "DYNAMIC_INDICES")
|
||||
.def(py::init<RootGuardManager*, py::set, py::list>())
|
||||
@ -7316,6 +7357,18 @@ PyObject* torch_c_dynamo_guards_init() {
|
||||
self.add_leaf_guard(std::make_shared<DUAL_LEVEL_MATCH>(
|
||||
self.get_root(), level, std::move(verbose_code_parts)));
|
||||
})
|
||||
.def(
|
||||
"add_float_is_nan_guard",
|
||||
[](GuardManager& self, py::object verbose_code_parts) -> void {
|
||||
self.add_leaf_guard(std::make_shared<FLOAT_IS_NAN>(
|
||||
self.get_root(), std::move(verbose_code_parts)));
|
||||
})
|
||||
.def(
|
||||
"add_complex_is_nan_guard",
|
||||
[](GuardManager& self, py::object verbose_code_parts) -> void {
|
||||
self.add_leaf_guard(std::make_shared<COMPLEX_IS_NAN>(
|
||||
self.get_root(), std::move(verbose_code_parts)));
|
||||
})
|
||||
.def(
|
||||
"add_dynamic_indices_guard",
|
||||
[](GuardManager& self,
|
||||
|
Reference in New Issue
Block a user