Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
consolidate guard_or_x and definitely_x (#152463)

definitely_true is almost the same as guard_or_false; the potential differences are not meaningful enough to justify keeping both. The same applies to definitely_false, which can be expressed in terms of guard_or_true and guard_or_false.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152463
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Committed by: PyTorch MergeBot
Parent: 72337bdcf2
Commit: 376529c78b
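
For orientation (not part of the commit), a minimal sketch of the semantics this PR standardizes on, using only the public torch.fx.experimental.symbolic_shapes helpers: guard_or_false takes over the role of definitely_true at call sites, guard_or_true takes over `not definitely_false(...)`, and both fall back to a fixed answer instead of raising when the condition is data dependent (unbacked).

    from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_false, guard_or_true

    shape_env = ShapeEnv()
    u0 = shape_env.create_unbacked_symint()  # unbacked SymInt: no hint to guard on

    # Plain (backed) conditions are simply evaluated.
    assert guard_or_false(2 + 2 == 4)

    # Undecidable conditions do not raise; they take the stated fallback.
    assert not guard_or_false(u0 == 3)  # old spelling: definitely_true(u0 == 3)
    assert guard_or_true(u0 == 3)       # old spelling: not definitely_false(u0 == 3)
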
@@ -419,11 +419,4 @@ inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
   return a.sym_ge(b);
 }
 
-inline bool definitely_true(
-    const c10::SymBool& b,
-    const char* file,
-    int64_t line) {
-  return b.has_hint() && b.guard_bool(file, line);
-}
-
 } // namespace c10
@@ -128,11 +128,11 @@ DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and
 
 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
   init_is_contiguous();
-  if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
+  if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   init_is_channels_last_contiguous();
-  if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
+  if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   return is_contiguous() | is_channels_last_contiguous() |
@@ -141,7 +141,7 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
 
 SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
   init_is_channels_last_contiguous();
-  if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
+  if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
     return false;
   }
   return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d();
@@ -149,7 +149,7 @@ SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
 
 SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
   init_is_channels_last_3d_contiguous();
-  if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
+  if (guard_or_false(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
     return false;
   }
   return ~is_channels_last_3d_contiguous() &
@@ -157,20 +157,20 @@ SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
 }
 
 SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const {
-  if (definitely_true(is_channels_last(), __FILE__, __LINE__)) {
+  if (guard_or_false(is_channels_last(), __FILE__, __LINE__)) {
     return false;
   }
   return ~is_channels_last() & compute_strides_like_channels_last_3d();
 }
 
 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
-  if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
+  if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
-  if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
+  if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
-  if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
+  if (guard_or_false(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   return is_contiguous() | is_channels_last_contiguous() |
@@ -178,7 +178,7 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
 }
 
 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
-  if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
+  if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   return is_contiguous() | compute_non_overlapping_and_dense();
@@ -38,8 +38,6 @@ torch.fx.experimental.symbolic_shapes
     is_concrete_float
     has_free_symbols
     has_free_unbacked_symbols
-    definitely_true
-    definitely_false
     guard_or_true
     guard_or_false
     guard_size_oblivious
@@ -2007,8 +2007,6 @@
     "cast_symbool_to_symint_guardless",
     "constrain_range",
     "constrain_unify",
-    "definitely_false",
-    "definitely_true",
     "guard_or_true",
     "guard_or_false",
     "error",
@@ -17,11 +17,7 @@ from torch._logging import getArtifactLogger
 from torch._subclasses.fake_tensor import FakeTensor
 from torch._subclasses.functional_tensor import FunctionalTensor
 from torch._subclasses.meta_utils import is_sparse_any
-from torch.fx.experimental.symbolic_shapes import (
-    definitely_true,
-    sym_eq,
-    SymIntEqByExpr,
-)
+from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq, SymIntEqByExpr
 from torch.multiprocessing.reductions import StorageWeakRef
 from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
@@ -317,13 +313,13 @@ def gen_alias_from_base(
 
 def has_same_metadata(t1, t2):
     return (
-        definitely_true(sym_eq(t1.size(), t2.size()))
-        and definitely_true(t1.layout == t2.layout)
+        guard_or_false(sym_eq(t1.size(), t2.size()))
+        and guard_or_false(t1.layout == t2.layout)
         and (
             is_sparse_any(t1)
             or (
-                definitely_true(sym_eq(t1.stride(), t2.stride()))
-                and definitely_true(t1.storage_offset() == t2.storage_offset())
+                guard_or_false(sym_eq(t1.stride(), t2.stride()))
+                and guard_or_false(t1.storage_offset() == t2.storage_offset())
             )
         )
         and t1.is_conj() == t2.is_conj()
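
The practical effect of the change above, as an illustration rather than code from the PR: when sizes or strides contain unbacked symbols, the comparison no longer raises a data-dependent error; guard_or_false answers False, so callers of has_same_metadata (e.g. gen_alias_from_base) take the conservative "metadata differ" path.

    from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq

    def sizes_known_equal(t1, t2) -> bool:
        # True only when equality is provable; an undecidable (unbacked)
        # comparison yields False instead of an error, as in has_same_metadata.
        return guard_or_false(sym_eq(t1.size(), t2.size()))
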
@@ -29,7 +29,7 @@ from torch.fx.experimental.proxy_tensor import (
     maybe_enable_thunkify,
 )
 from torch.fx.experimental.symbolic_shapes import (
-    definitely_false,
+    guard_or_true,
     PropagateUnbackedSymInts,
     sym_eq,
 )
@@ -221,12 +221,15 @@ def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
                 # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
                 # The issue is that we are sensitive to decomps that don't accurately maintain
                 # their output's _base.shape compared to eager mode, and this helps mitigate a bit.
-                # The not definitely_false is also sketchy; if unbacked
+                # The guard_or_true is also sketchy; if unbacked
                 # symints are involved, we're just going to assume that the
                 # decomps setup the base shape correctly
+
+                # Return out if the result of out.shape==tangent.shape is unknown or known to be true.
+                # otherwise if its a known false return out.view(tangent.shape).
                 needed_outs.append(
                     out
-                    if not definitely_false(sym_eq(out.shape, tangent.shape))
+                    if guard_or_true(sym_eq(out.shape, tangent.shape))
                     else out.view(tangent.shape)
                 )
                 needed_tangents.append(tangent)
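
The rewrite here is behavior-preserving: `not definitely_false(e)` and `guard_or_true(e)` both answer True when `e` is known true or undecidable, and False only when `e` is provably false. A hedged sketch of the selection the added comments describe:

    from torch.fx.experimental.symbolic_shapes import guard_or_true, sym_eq

    def pick_needed_out(out, tangent):
        # Keep `out` when the shapes are known equal or the comparison is
        # undecidable; only a provably different shape triggers the view.
        if guard_or_true(sym_eq(out.shape, tangent.shape)):
            return out
        return out.view(tangent.shape)
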
@@ -33,7 +33,7 @@ from torch._prims_common import (
     type_to_dtype,
 )
 from torch.fx.experimental.symbolic_shapes import (
-    definitely_true,
+    guard_or_false,
     guard_size_oblivious,
     statically_known_true,
 )
@@ -305,8 +305,8 @@ def addmm(
         return alpha * out + beta * self
     if (
         statically_known_true(mat1.size(0) == 1)
-        and definitely_true(mat2.size(0) <= 16)
-        and definitely_true(mat2.size(1) <= 16)
+        and guard_or_false(mat2.size(0) <= 16)
+        and guard_or_false(mat2.size(1) <= 16)
     ):
         counters["inductor"]["decompose_addmm"] += 1
         out = (mat1.T * mat2).sum(dim=0, keepdim=True)
@@ -336,7 +336,7 @@ def mm(
         and statically_known_true(self.size(0) > 0)
         and statically_known_true(input2.size(0) == 1)
         and (self.dtype == input2.dtype)
-        and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
+        and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32)
     ):
         counters["inductor"]["decompose_mm"] += 1
         return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
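
Both inductor decompositions above follow the same gating pattern, sketched below (the helper name and standalone form are illustrative, not part of the PR): the size checks guard an optional optimization, so guard_or_false makes them fail closed and the decomposition is simply skipped, rather than erroring, when a dimension is unbacked.

    from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true

    def should_decompose_addmm(mat1, mat2) -> bool:
        # Optional-optimization gate: unknown sizes just disable the rewrite.
        return (
            statically_known_true(mat1.size(0) == 1)
            and guard_or_false(mat2.size(0) <= 16)
            and guard_or_false(mat2.size(1) <= 16)
        )
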
@@ -924,7 +924,7 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
     Infers the size of a dim with size -1, if it exists.
     Also checks that new shape is compatible with the number of elements.
     """
-    from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
 
     dim = None
     newsize = 1
@@ -952,7 +952,7 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
         lambda: (
             f"cannot reshape tensor of 0 elements into shape {list(shape)} because the "
             f"unspecified dimension size -1 can be any value and is ambiguous"
-            if definitely_true(numel == 0)
+            if guard_or_false(numel == 0)
            else f"shape '{list(shape)}' is invalid for input of size {numel}"
         ),
     )
@@ -94,7 +94,7 @@ at::Tensor InputMetadata::maybe_reduce(
     if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
       // NB: we could short circuit this once needs_reduce is true but there's
       // no point since the reduction function will guard on this anyway
-      if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) {
+      if (!c10::guard_or_false(size.sym_eq(target), __FILE__, __LINE__)) {
        needs_reduce = true;
      }
    } else {
@@ -1230,10 +1230,22 @@ def _log_suppressed_dde(a: SymBool, assumed_value: bool) -> None:
 # of various framework code. Those would be used in situations you prefer to guard and know
 # the result of the expression over not guarding, but in case you hit a data dependent error
 # you are ok with just returning true or false.
 # Some reasons you might be ok with returning true/false instead could be:
 # (1) It's an optimization/additional check I do not want to fail for not performing it.
 # (2) I am willing to deviate from the normal semantics when I have unbacked for the
 # benefit of not failing.
+#
+# When to use this?
+# (1) If you can use a higher level combinator prefer using those instead, they are definitely safe (modulo short-circuiting).
+#
+# (2) It can be used if the program would behave equivalently if _guard_or returned true or false.
+# Many inductor optimizations fall in this bracket for example.
+#
+# (3) Finally, it's even be OK if the program wouldn't behave equivalently, so long as the
+# change is semantics preserving. It can be semantics preserving if the program errors in more
+# cases than it did previously (but otherwise behaves identically), or if it changes some quantity
+# in a way that doesn't matter (e.g., strides often fall in this bucket.)
+#
+# (4) Specialize for the general case and add a runtime assertion that would fail during
+# runtime if the conditions for the general case are not satisfied. Examples for this are;
+# assuming expand/reshape inputs are not -1. or assuming the non-broadcasting path.
+#
 def _guard_or(a: BoolLikeType, default: bool) -> bool:
     if not isinstance(a, SymBool):
         assert isinstance(a, bool)
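
A minimal sketch of the contract the comment block describes, assuming the module's private _static_eval_sym_bool helper and omitting the guard-on-hinted-symbols step of the real _guard_or:

    from torch import SymBool
    from torch.fx.experimental.symbolic_shapes import _static_eval_sym_bool

    def guard_or_sketch(a, default: bool) -> bool:
        # Plain bools pass through unchanged.
        if not isinstance(a, SymBool):
            assert isinstance(a, bool)
            return a
        # Decide the condition statically when possible; otherwise fall back
        # to the caller's default rather than raising a data-dependent error.
        decided = _static_eval_sym_bool(a)
        return default if decided is None else decided
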
@@ -1272,48 +1284,6 @@ def guard_or_true(a: BoolLikeType) -> bool:
     return _guard_or(a, True)
 
 
-def definitely_true(a: BoolLikeType) -> bool:
-    """
-    Returns True only if we can tell that a is True, possibly introducing
-    a guard in the process. If a depends on some unbacked SymInt, we may
-    return False even though there may exist a possible value of the SymInt
-    that would cause the expression to return True.
-
-    When is it appropriate to use definitely_true? First, if you can use
-    a higher level combinator prefer using those instead, they are definitely
-    safe (modulo short-circuiting).
-    Second, it can be used if the program would behave equivalently if
-    definitely_true always returned False. Finally, it even
-    be OK if the program wouldn't behave equivalently, so long as the
-    change is semantics preserving. It can be semantics preserving if
-    the program errors in more cases than it did previously (but otherwise
-    behaves identically), or if it changes some quantity in a way that
-    doesn't matter (e.g., strides often fall in this bucket.)
-    """
-    if isinstance(a, SymBool):
-        if a.node.has_hint():
-            return guard_bool(a)
-        else:
-            return False
-    return bool(a)
-
-
-def definitely_false(a: BoolLikeType) -> bool:
-    """
-    Returns True only if we can tell that a is False, possibly introducing
-    a guard in the process. If a depends on some unbacked SymInt, we may
-    return False even though there may exist a possible value of the SymInt
-    that would cause the expression a to be False. See definitely_true
-    for more usage guidance.
-    """
-    if isinstance(a, SymBool):
-        if a.node.has_hint():
-            return not guard_bool(a)
-        else:
-            return False
-    return not bool(a)
-
-
 def _static_eval_sym_bool(x: SymBool) -> Optional[bool]:
     assert isinstance(x, SymBool)
     expr = x.node.expr
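
To summarize the consolidation (a hedged mapping, not library code): call sites that used the removed helpers now express the same intent with the guard_or_* family, and per the commit message the residual behavioral differences are not considered meaningful.

    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

    def definitely_true_compat(a) -> bool:
        # old behavior: evaluate when a hint exists, otherwise False
        return guard_or_false(a)

    def not_definitely_false_compat(a) -> bool:
        # old spelling: not definitely_false(a)
        return guard_or_true(a)
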