consolidate guard_or_x and definitely_x (#152463)

definitely_true is almost same as guard_or_false, the potential differences are not meaningful to a degree that justify the
existence of both. same for definitely_false, it can be expressed with guard_or_true and guard_or_false.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152463
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Laith Sakka
2025-04-29 13:52:43 -07:00
committed by PyTorch MergeBot
parent 72337bdcf2
commit 376529c78b
10 changed files with 43 additions and 85 deletions

View File

@ -419,11 +419,4 @@ inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
return a.sym_ge(b); return a.sym_ge(b);
} }
inline bool definitely_true(
const c10::SymBool& b,
const char* file,
int64_t line) {
return b.has_hint() && b.guard_bool(file, line);
}
} // namespace c10 } // namespace c10

View File

@ -128,11 +128,11 @@ DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const { SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
init_is_contiguous(); init_is_contiguous();
if (definitely_true(is_contiguous(), __FILE__, __LINE__)) { if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) {
return true; return true;
} }
init_is_channels_last_contiguous(); init_is_channels_last_contiguous();
if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) { if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
return true; return true;
} }
return is_contiguous() | is_channels_last_contiguous() | return is_contiguous() | is_channels_last_contiguous() |
@ -141,7 +141,7 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const { SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
init_is_channels_last_contiguous(); init_is_channels_last_contiguous();
if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) { if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
return false; return false;
} }
return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d(); return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d();
@ -149,7 +149,7 @@ SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const { SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
init_is_channels_last_3d_contiguous(); init_is_channels_last_3d_contiguous();
if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) { if (guard_or_false(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
return false; return false;
} }
return ~is_channels_last_3d_contiguous() & return ~is_channels_last_3d_contiguous() &
@ -157,20 +157,20 @@ SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
} }
SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const { SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const {
if (definitely_true(is_channels_last(), __FILE__, __LINE__)) { if (guard_or_false(is_channels_last(), __FILE__, __LINE__)) {
return false; return false;
} }
return ~is_channels_last() & compute_strides_like_channels_last_3d(); return ~is_channels_last() & compute_strides_like_channels_last_3d();
} }
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const { SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
if (definitely_true(is_contiguous(), __FILE__, __LINE__)) { if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) {
return true; return true;
} }
if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) { if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
return true; return true;
} }
if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) { if (guard_or_false(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
return true; return true;
} }
return is_contiguous() | is_channels_last_contiguous() | return is_contiguous() | is_channels_last_contiguous() |
@ -178,7 +178,7 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
} }
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const { SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
if (definitely_true(is_contiguous(), __FILE__, __LINE__)) { if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) {
return true; return true;
} }
return is_contiguous() | compute_non_overlapping_and_dense(); return is_contiguous() | compute_non_overlapping_and_dense();

View File

@ -38,8 +38,6 @@ torch.fx.experimental.symbolic_shapes
is_concrete_float is_concrete_float
has_free_symbols has_free_symbols
has_free_unbacked_symbols has_free_unbacked_symbols
definitely_true
definitely_false
guard_or_true guard_or_true
guard_or_false guard_or_false
guard_size_oblivious guard_size_oblivious

View File

@ -2007,8 +2007,6 @@
"cast_symbool_to_symint_guardless", "cast_symbool_to_symint_guardless",
"constrain_range", "constrain_range",
"constrain_unify", "constrain_unify",
"definitely_false",
"definitely_true",
"guard_or_true", "guard_or_true",
"guard_or_false", "guard_or_false",
"error", "error",

View File

@ -17,11 +17,7 @@ from torch._logging import getArtifactLogger
from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor from torch._subclasses.functional_tensor import FunctionalTensor
from torch._subclasses.meta_utils import is_sparse_any from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import ( from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq, SymIntEqByExpr
definitely_true,
sym_eq,
SymIntEqByExpr,
)
from torch.multiprocessing.reductions import StorageWeakRef from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import ( from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass, is_traceable_wrapper_subclass,
@ -317,13 +313,13 @@ def gen_alias_from_base(
def has_same_metadata(t1, t2): def has_same_metadata(t1, t2):
return ( return (
definitely_true(sym_eq(t1.size(), t2.size())) guard_or_false(sym_eq(t1.size(), t2.size()))
and definitely_true(t1.layout == t2.layout) and guard_or_false(t1.layout == t2.layout)
and ( and (
is_sparse_any(t1) is_sparse_any(t1)
or ( or (
definitely_true(sym_eq(t1.stride(), t2.stride())) guard_or_false(sym_eq(t1.stride(), t2.stride()))
and definitely_true(t1.storage_offset() == t2.storage_offset()) and guard_or_false(t1.storage_offset() == t2.storage_offset())
) )
) )
and t1.is_conj() == t2.is_conj() and t1.is_conj() == t2.is_conj()

View File

@ -29,7 +29,7 @@ from torch.fx.experimental.proxy_tensor import (
maybe_enable_thunkify, maybe_enable_thunkify,
) )
from torch.fx.experimental.symbolic_shapes import ( from torch.fx.experimental.symbolic_shapes import (
definitely_false, guard_or_true,
PropagateUnbackedSymInts, PropagateUnbackedSymInts,
sym_eq, sym_eq,
) )
@ -221,12 +221,15 @@ def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
# A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32 # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
# The issue is that we are sensitive to decomps that don't accurately maintain # The issue is that we are sensitive to decomps that don't accurately maintain
# their output's _base.shape compared to eager mode, and this helps mitigate a bit. # their output's _base.shape compared to eager mode, and this helps mitigate a bit.
# The not definitely_false is also sketchy; if unbacked # The guard_or_true also sketchy; if unbacked
# symints are involved, we're just going to assume that the # symints are involved, we're just going to assume that the
# decomps setup the base shape correctly # decomps setup the base shape correctly
# Return out if the result of out.shape==tangent.shape is unknown or known to be true.
# otherwise if its a known false return out.view(tangent.shape).
needed_outs.append( needed_outs.append(
out out
if not definitely_false(sym_eq(out.shape, tangent.shape)) if guard_or_true(sym_eq(out.shape, tangent.shape))
else out.view(tangent.shape) else out.view(tangent.shape)
) )
needed_tangents.append(tangent) needed_tangents.append(tangent)

View File

@ -33,7 +33,7 @@ from torch._prims_common import (
type_to_dtype, type_to_dtype,
) )
from torch.fx.experimental.symbolic_shapes import ( from torch.fx.experimental.symbolic_shapes import (
definitely_true, guard_or_false,
guard_size_oblivious, guard_size_oblivious,
statically_known_true, statically_known_true,
) )
@ -305,8 +305,8 @@ def addmm(
return alpha * out + beta * self return alpha * out + beta * self
if ( if (
statically_known_true(mat1.size(0) == 1) statically_known_true(mat1.size(0) == 1)
and definitely_true(mat2.size(0) <= 16) and guard_or_false(mat2.size(0) <= 16)
and definitely_true(mat2.size(1) <= 16) and guard_or_false(mat2.size(1) <= 16)
): ):
counters["inductor"]["decompose_addmm"] += 1 counters["inductor"]["decompose_addmm"] += 1
out = (mat1.T * mat2).sum(dim=0, keepdim=True) out = (mat1.T * mat2).sum(dim=0, keepdim=True)
@ -336,7 +336,7 @@ def mm(
and statically_known_true(self.size(0) > 0) and statically_known_true(self.size(0) > 0)
and statically_known_true(input2.size(0) == 1) and statically_known_true(input2.size(0) == 1)
and (self.dtype == input2.dtype) and (self.dtype == input2.dtype)
and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32) and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32)
): ):
counters["inductor"]["decompose_mm"] += 1 counters["inductor"]["decompose_mm"] += 1
return torch.cat([self[i, :] * input2 for i in range(self.size(0))]) return torch.cat([self[i, :] * input2 for i in range(self.size(0))])

View File

@ -924,7 +924,7 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
Infers the size of a dim with size -1, if it exists. Infers the size of a dim with size -1, if it exists.
Also checks that new shape is compatible with the number of elements. Also checks that new shape is compatible with the number of elements.
""" """
from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false from torch.fx.experimental.symbolic_shapes import guard_or_false
dim = None dim = None
newsize = 1 newsize = 1
@ -952,7 +952,7 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
lambda: ( lambda: (
f"cannot reshape tensor of 0 elements into shape {list(shape)} because the " f"cannot reshape tensor of 0 elements into shape {list(shape)} because the "
f"unspecified dimension size -1 can be any value and is ambiguous" f"unspecified dimension size -1 can be any value and is ambiguous"
if definitely_true(numel == 0) if guard_or_false(numel == 0)
else f"shape '{list(shape)}' is invalid for input of size {numel}" else f"shape '{list(shape)}' is invalid for input of size {numel}"
), ),
) )

View File

@ -94,7 +94,7 @@ at::Tensor InputMetadata::maybe_reduce(
if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) { if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
// NB: we could short circuit this once needs_reduce is true but there's // NB: we could short circuit this once needs_reduce is true but there's
// no point since the reduction function will guard on this anyway // no point since the reduction function will guard on this anyway
if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) { if (!c10::guard_or_false(size.sym_eq(target), __FILE__, __LINE__)) {
needs_reduce = true; needs_reduce = true;
} }
} else { } else {

View File

@ -1230,10 +1230,22 @@ def _log_suppressed_dde(a: SymBool, assumed_value: bool) -> None:
# of various framework code. Those would be used in situations you prefer to guard and know # of various framework code. Those would be used in situations you prefer to guard and know
# the result of the expression over not guarding, but in case you hit a data dependent error # the result of the expression over not guarding, but in case you hit a data dependent error
# you are ok with just returning true or false. # you are ok with just returning true or false.
# Some reasons you might be ok with returning true/false instead could be: #
# (1) It's an optimization/additional check I do not want to fail for not performing it. # When to use this?
# (2) I am willing to deviate from the normal semantics when I have unbacked for the # (1) If you can use a higher level combinator prefer using those instead, they are definitely safe (modulo short-circuiting).
# benefit of not failing. #
# (2) It can be used if the program would behave equivalently if _guard_or returned true or false.
# Many inductor optimizations fall in this bracket for example.
#
# (3) Finally, it's even be OK if the program wouldn't behave equivalently, so long as the
# change is semantics preserving. It can be semantics preserving if the program errors in more
# cases than it did previously (but otherwise behaves identically), or if it changes some quantity
# in a way that doesn't matter (e.g., strides often fall in this bucket.)
#
# (4) Specialize for the general case and add a runtime assertion that would fail during
# runtime if the conditions for the general case are not satisfied. Examples for this are;
# assuming expand/reshape inputs are not -1. or assuming the non-broadcasting path.
#
def _guard_or(a: BoolLikeType, default: bool) -> bool: def _guard_or(a: BoolLikeType, default: bool) -> bool:
if not isinstance(a, SymBool): if not isinstance(a, SymBool):
assert isinstance(a, bool) assert isinstance(a, bool)
@ -1272,48 +1284,6 @@ def guard_or_true(a: BoolLikeType) -> bool:
return _guard_or(a, True) return _guard_or(a, True)
def definitely_true(a: BoolLikeType) -> bool:
"""
Returns True only if we can tell that a is True, possibly introducing
a guard in the process. If a depends on some unbacked SymInt, we may
return False even though there may exist a possible value of the SymInt
that would cause the expression to return True.
When is it appropriate to use definitely_true? First, if you can use
a higher level combinator prefer using those instead, they are definitely
safe (modulo short-circuiting).
Second, it can be used if the program would behave equivalently if
definitely_true always returned False. Finally, it even
be OK if the program wouldn't behave equivalently, so long as the
change is semantics preserving. It can be semantics preserving if
the program errors in more cases than it did previously (but otherwise
behaves identically), or if it changes some quantity in a way that
doesn't matter (e.g., strides often fall in this bucket.)
"""
if isinstance(a, SymBool):
if a.node.has_hint():
return guard_bool(a)
else:
return False
return bool(a)
def definitely_false(a: BoolLikeType) -> bool:
"""
Returns True only if we can tell that a is False, possibly introducing
a guard in the process. If a depends on some unbacked SymInt, we may
return False even though there may exist a possible value of the SymInt
that would cause the expression a to be False. See definitely_true
for more usage guidance.
"""
if isinstance(a, SymBool):
if a.node.has_hint():
return not guard_bool(a)
else:
return False
return not bool(a)
def _static_eval_sym_bool(x: SymBool) -> Optional[bool]: def _static_eval_sym_bool(x: SymBool) -> Optional[bool]:
assert isinstance(x, SymBool) assert isinstance(x, SymBool)
expr = x.node.expr expr = x.node.expr