mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
check fallback_value first. (#154493)
This is just a refactor, not a fix for any issue. we do check fallback_value first and early exit instead of checking it not set over and over. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154493 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
447b481c79
commit
cd9ff41282
@ -7382,14 +7382,15 @@ class ShapeEnv:
|
||||
if not (new_expr.free_symbols <= self.var_to_val.keys()):
|
||||
ok = False
|
||||
|
||||
# TODO maybe deprecate this feature.
|
||||
# fallback_value is set when guard_or_true or guard_or_false are used.
|
||||
if not ok and fallback_value is not None:
|
||||
self._log_suppressed_dde(orig_expr, fallback_value)
|
||||
return fallback_value
|
||||
|
||||
# oblivious_var_to_val will be defined iff we have sizes with DimDynamic.OBLIVIOUS_SIZE type.
|
||||
# Here we handle falling back to the hint for dimensions of type DimDynamic.OBLIVIOUS_SIZE.
|
||||
# Those are backed dimentions that are treated as unbacked to avoid specializations, but if
|
||||
# we fail to bypass with size oblivious reasoning we compute using the actual hint and guard.
|
||||
# See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
|
||||
if (
|
||||
fallback_value is None # do not do this under guard_or
|
||||
and self.oblivious_var_to_val
|
||||
self.oblivious_var_to_val
|
||||
and not (
|
||||
correct_hint := orig_expr.xreplace(
|
||||
self.oblivious_var_to_val
|
||||
@ -7418,10 +7419,8 @@ class ShapeEnv:
|
||||
# unbacked_var_to_val is not None iff propagate_real_tensors is on.
|
||||
# if propagate_real_tensors is on, we check the example values to generate (unsound_result)
|
||||
# and if they pass we add a runtime assertions and continue.
|
||||
|
||||
if (
|
||||
fallback_value is None # do not do this under guard_or
|
||||
and not ok
|
||||
not ok
|
||||
and self.unbacked_var_to_val
|
||||
and not (
|
||||
unsound_result := orig_expr.xreplace(
|
||||
@ -7435,18 +7434,12 @@ class ShapeEnv:
|
||||
ok = True
|
||||
|
||||
# Check if this is coming from a python assert statement, if so, convert it to a runtime assertion
|
||||
# if instead of failing.
|
||||
# instead of failing.
|
||||
if not ok and self.trace_asserts and self._is_python_assert():
|
||||
concrete_val = sympy.true
|
||||
transmute_into_runtime_assert = True
|
||||
ok = True
|
||||
|
||||
# fallback value is set when guard_or_true, gaurd_or_false are used.
|
||||
# whe we fail to evaluate soundly, we use the default value set by it.
|
||||
if not ok and fallback_value is not None:
|
||||
self._log_suppressed_dde(orig_expr, fallback_value)
|
||||
return fallback_value
|
||||
|
||||
if not ok:
|
||||
size_oblivious_result = None
|
||||
# compute size_oblivious_result to suggest it as a fix for the user if it works.
|
||||
|
Reference in New Issue
Block a user