check fallback_value first. (#154493)

This is just a refactor, not a fix for any issue.
we do check fallback_value first  and early exit instead of checking it not set over and over.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154493
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Laith Sakka
2025-05-28 21:12:55 -07:00
committed by PyTorch MergeBot
parent 447b481c79
commit cd9ff41282

View File

@ -7382,14 +7382,15 @@ class ShapeEnv:
if not (new_expr.free_symbols <= self.var_to_val.keys()):
ok = False
# TODO maybe deprecate this feature.
# fallback_value is set when guard_or_true or guard_or_false are used.
if not ok and fallback_value is not None:
self._log_suppressed_dde(orig_expr, fallback_value)
return fallback_value
# oblivious_var_to_val will be defined iff we have sizes with DimDynamic.OBLIVIOUS_SIZE type.
# Here we handle falling back to the hint for dimensions of type DimDynamic.OBLIVIOUS_SIZE.
# Those are backed dimentions that are treated as unbacked to avoid specializations, but if
# we fail to bypass with size oblivious reasoning we compute using the actual hint and guard.
# See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
if (
fallback_value is None # do not do this under guard_or
and self.oblivious_var_to_val
self.oblivious_var_to_val
and not (
correct_hint := orig_expr.xreplace(
self.oblivious_var_to_val
@ -7418,10 +7419,8 @@ class ShapeEnv:
# unbacked_var_to_val is not None iff propagate_real_tensors is on.
# if propagate_real_tensors is on, we check the example values to generate (unsound_result)
# and if they pass we add a runtime assertions and continue.
if (
fallback_value is None # do not do this under guard_or
and not ok
not ok
and self.unbacked_var_to_val
and not (
unsound_result := orig_expr.xreplace(
@ -7435,18 +7434,12 @@ class ShapeEnv:
ok = True
# Check if this is coming from a python assert statement, if so, convert it to a runtime assertion
# if instead of failing.
# instead of failing.
if not ok and self.trace_asserts and self._is_python_assert():
concrete_val = sympy.true
transmute_into_runtime_assert = True
ok = True
# fallback value is set when guard_or_true, gaurd_or_false are used.
# whe we fail to evaluate soundly, we use the default value set by it.
if not ok and fallback_value is not None:
self._log_suppressed_dde(orig_expr, fallback_value)
return fallback_value
if not ok:
size_oblivious_result = None
# compute size_oblivious_result to suggest it as a fix for the user if it works.