mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Add FLOAT_IS_NAN and COMPLEX_IS_NAN guards (#162537)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162537 Approved by: https://github.com/anijain2305, https://github.com/mlazos ghstack dependencies: #162528
This commit is contained in:
committed by
PyTorch MergeBot
parent
5dd84559a5
commit
79d2418b5a
@ -2197,26 +2197,20 @@ class GuardBuilder(GuardBuilderBase):
|
||||
|
||||
# Special case for nan because float("nan") == float("nan") evaluates to False
|
||||
if istype(val, float) and math.isnan(val):
|
||||
self.TYPE_MATCH(guard)
|
||||
code = []
|
||||
code.append(f"__math_isnan({ref})")
|
||||
code = [f"(type({ref}) is float and __math_isnan({ref}))"]
|
||||
self._set_guard_export_info(guard, code)
|
||||
|
||||
self.get_guard_manager(guard).add_lambda_guard_no_framelocals(
|
||||
_get_closure_vars()["__math_isnan"], # type: ignore[arg-type]
|
||||
self.get_guard_manager(guard).add_float_is_nan_guard(
|
||||
get_verbose_code_parts(code, guard),
|
||||
)
|
||||
return
|
||||
|
||||
# Python math library doesn't support complex nan, so we need to use numpy
|
||||
if istype(val, complex) and np.isnan(val):
|
||||
self.TYPE_MATCH(guard)
|
||||
code = []
|
||||
code.append(f"__numpy_isnan({ref})")
|
||||
code = [f"(type({ref}) is complex and __numpy_isnan({ref}))"]
|
||||
self._set_guard_export_info(guard, code)
|
||||
|
||||
self.get_guard_manager(guard).add_lambda_guard_no_framelocals(
|
||||
_get_closure_vars()["__numpy_isnan"], # type: ignore[arg-type]
|
||||
self.get_guard_manager(guard).add_complex_is_nan_guard(
|
||||
get_verbose_code_parts(code, guard),
|
||||
)
|
||||
return
|
||||
|
Reference in New Issue
Block a user