mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
consolidate guard_or_x and definitely_x (#152463)
definitely_true is almost same as guard_or_false, the potential differences are not meaningful to a degree that justify the existence of both. same for definitely_false, it can be expressed with guard_or_true and guard_or_false. Pull Request resolved: https://github.com/pytorch/pytorch/pull/152463 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
72337bdcf2
commit
376529c78b
@ -419,11 +419,4 @@ inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
|
|||||||
return a.sym_ge(b);
|
return a.sym_ge(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool definitely_true(
|
|
||||||
const c10::SymBool& b,
|
|
||||||
const char* file,
|
|
||||||
int64_t line) {
|
|
||||||
return b.has_hint() && b.guard_bool(file, line);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
@ -128,11 +128,11 @@ DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and
|
|||||||
|
|
||||||
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
|
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
|
||||||
init_is_contiguous();
|
init_is_contiguous();
|
||||||
if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
|
if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
init_is_channels_last_contiguous();
|
init_is_channels_last_contiguous();
|
||||||
if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
|
if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return is_contiguous() | is_channels_last_contiguous() |
|
return is_contiguous() | is_channels_last_contiguous() |
|
||||||
@ -141,7 +141,7 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
|
|||||||
|
|
||||||
SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
|
SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
|
||||||
init_is_channels_last_contiguous();
|
init_is_channels_last_contiguous();
|
||||||
if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
|
if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d();
|
return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d();
|
||||||
@ -149,7 +149,7 @@ SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
|
|||||||
|
|
||||||
SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
|
SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
|
||||||
init_is_channels_last_3d_contiguous();
|
init_is_channels_last_3d_contiguous();
|
||||||
if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
|
if (guard_or_false(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return ~is_channels_last_3d_contiguous() &
|
return ~is_channels_last_3d_contiguous() &
|
||||||
@ -157,20 +157,20 @@ SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const {
|
SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const {
|
||||||
if (definitely_true(is_channels_last(), __FILE__, __LINE__)) {
|
if (guard_or_false(is_channels_last(), __FILE__, __LINE__)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return ~is_channels_last() & compute_strides_like_channels_last_3d();
|
return ~is_channels_last() & compute_strides_like_channels_last_3d();
|
||||||
}
|
}
|
||||||
|
|
||||||
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
|
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
|
||||||
if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
|
if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
|
if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
|
if (guard_or_false(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return is_contiguous() | is_channels_last_contiguous() |
|
return is_contiguous() | is_channels_last_contiguous() |
|
||||||
@ -178,7 +178,7 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
|
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
|
||||||
if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
|
if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return is_contiguous() | compute_non_overlapping_and_dense();
|
return is_contiguous() | compute_non_overlapping_and_dense();
|
||||||
|
@ -38,8 +38,6 @@ torch.fx.experimental.symbolic_shapes
|
|||||||
is_concrete_float
|
is_concrete_float
|
||||||
has_free_symbols
|
has_free_symbols
|
||||||
has_free_unbacked_symbols
|
has_free_unbacked_symbols
|
||||||
definitely_true
|
|
||||||
definitely_false
|
|
||||||
guard_or_true
|
guard_or_true
|
||||||
guard_or_false
|
guard_or_false
|
||||||
guard_size_oblivious
|
guard_size_oblivious
|
||||||
|
@ -2007,8 +2007,6 @@
|
|||||||
"cast_symbool_to_symint_guardless",
|
"cast_symbool_to_symint_guardless",
|
||||||
"constrain_range",
|
"constrain_range",
|
||||||
"constrain_unify",
|
"constrain_unify",
|
||||||
"definitely_false",
|
|
||||||
"definitely_true",
|
|
||||||
"guard_or_true",
|
"guard_or_true",
|
||||||
"guard_or_false",
|
"guard_or_false",
|
||||||
"error",
|
"error",
|
||||||
|
@ -17,11 +17,7 @@ from torch._logging import getArtifactLogger
|
|||||||
from torch._subclasses.fake_tensor import FakeTensor
|
from torch._subclasses.fake_tensor import FakeTensor
|
||||||
from torch._subclasses.functional_tensor import FunctionalTensor
|
from torch._subclasses.functional_tensor import FunctionalTensor
|
||||||
from torch._subclasses.meta_utils import is_sparse_any
|
from torch._subclasses.meta_utils import is_sparse_any
|
||||||
from torch.fx.experimental.symbolic_shapes import (
|
from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq, SymIntEqByExpr
|
||||||
definitely_true,
|
|
||||||
sym_eq,
|
|
||||||
SymIntEqByExpr,
|
|
||||||
)
|
|
||||||
from torch.multiprocessing.reductions import StorageWeakRef
|
from torch.multiprocessing.reductions import StorageWeakRef
|
||||||
from torch.utils._python_dispatch import (
|
from torch.utils._python_dispatch import (
|
||||||
is_traceable_wrapper_subclass,
|
is_traceable_wrapper_subclass,
|
||||||
@ -317,13 +313,13 @@ def gen_alias_from_base(
|
|||||||
|
|
||||||
def has_same_metadata(t1, t2):
|
def has_same_metadata(t1, t2):
|
||||||
return (
|
return (
|
||||||
definitely_true(sym_eq(t1.size(), t2.size()))
|
guard_or_false(sym_eq(t1.size(), t2.size()))
|
||||||
and definitely_true(t1.layout == t2.layout)
|
and guard_or_false(t1.layout == t2.layout)
|
||||||
and (
|
and (
|
||||||
is_sparse_any(t1)
|
is_sparse_any(t1)
|
||||||
or (
|
or (
|
||||||
definitely_true(sym_eq(t1.stride(), t2.stride()))
|
guard_or_false(sym_eq(t1.stride(), t2.stride()))
|
||||||
and definitely_true(t1.storage_offset() == t2.storage_offset())
|
and guard_or_false(t1.storage_offset() == t2.storage_offset())
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
and t1.is_conj() == t2.is_conj()
|
and t1.is_conj() == t2.is_conj()
|
||||||
|
@ -29,7 +29,7 @@ from torch.fx.experimental.proxy_tensor import (
|
|||||||
maybe_enable_thunkify,
|
maybe_enable_thunkify,
|
||||||
)
|
)
|
||||||
from torch.fx.experimental.symbolic_shapes import (
|
from torch.fx.experimental.symbolic_shapes import (
|
||||||
definitely_false,
|
guard_or_true,
|
||||||
PropagateUnbackedSymInts,
|
PropagateUnbackedSymInts,
|
||||||
sym_eq,
|
sym_eq,
|
||||||
)
|
)
|
||||||
@ -221,12 +221,15 @@ def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
|
|||||||
# A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
|
# A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
|
||||||
# The issue is that we are sensitive to decomps that don't accurately maintain
|
# The issue is that we are sensitive to decomps that don't accurately maintain
|
||||||
# their output's _base.shape compared to eager mode, and this helps mitigate a bit.
|
# their output's _base.shape compared to eager mode, and this helps mitigate a bit.
|
||||||
# The not definitely_false is also sketchy; if unbacked
|
# The guard_or_true also sketchy; if unbacked
|
||||||
# symints are involved, we're just going to assume that the
|
# symints are involved, we're just going to assume that the
|
||||||
# decomps setup the base shape correctly
|
# decomps setup the base shape correctly
|
||||||
|
|
||||||
|
# Return out if the result of out.shape==tangent.shape is unknown or known to be true.
|
||||||
|
# otherwise if its a known false return out.view(tangent.shape).
|
||||||
needed_outs.append(
|
needed_outs.append(
|
||||||
out
|
out
|
||||||
if not definitely_false(sym_eq(out.shape, tangent.shape))
|
if guard_or_true(sym_eq(out.shape, tangent.shape))
|
||||||
else out.view(tangent.shape)
|
else out.view(tangent.shape)
|
||||||
)
|
)
|
||||||
needed_tangents.append(tangent)
|
needed_tangents.append(tangent)
|
||||||
|
@ -33,7 +33,7 @@ from torch._prims_common import (
|
|||||||
type_to_dtype,
|
type_to_dtype,
|
||||||
)
|
)
|
||||||
from torch.fx.experimental.symbolic_shapes import (
|
from torch.fx.experimental.symbolic_shapes import (
|
||||||
definitely_true,
|
guard_or_false,
|
||||||
guard_size_oblivious,
|
guard_size_oblivious,
|
||||||
statically_known_true,
|
statically_known_true,
|
||||||
)
|
)
|
||||||
@ -305,8 +305,8 @@ def addmm(
|
|||||||
return alpha * out + beta * self
|
return alpha * out + beta * self
|
||||||
if (
|
if (
|
||||||
statically_known_true(mat1.size(0) == 1)
|
statically_known_true(mat1.size(0) == 1)
|
||||||
and definitely_true(mat2.size(0) <= 16)
|
and guard_or_false(mat2.size(0) <= 16)
|
||||||
and definitely_true(mat2.size(1) <= 16)
|
and guard_or_false(mat2.size(1) <= 16)
|
||||||
):
|
):
|
||||||
counters["inductor"]["decompose_addmm"] += 1
|
counters["inductor"]["decompose_addmm"] += 1
|
||||||
out = (mat1.T * mat2).sum(dim=0, keepdim=True)
|
out = (mat1.T * mat2).sum(dim=0, keepdim=True)
|
||||||
@ -336,7 +336,7 @@ def mm(
|
|||||||
and statically_known_true(self.size(0) > 0)
|
and statically_known_true(self.size(0) > 0)
|
||||||
and statically_known_true(input2.size(0) == 1)
|
and statically_known_true(input2.size(0) == 1)
|
||||||
and (self.dtype == input2.dtype)
|
and (self.dtype == input2.dtype)
|
||||||
and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
|
and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32)
|
||||||
):
|
):
|
||||||
counters["inductor"]["decompose_mm"] += 1
|
counters["inductor"]["decompose_mm"] += 1
|
||||||
return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
|
return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
|
||||||
|
@ -924,7 +924,7 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
|
|||||||
Infers the size of a dim with size -1, if it exists.
|
Infers the size of a dim with size -1, if it exists.
|
||||||
Also checks that new shape is compatible with the number of elements.
|
Also checks that new shape is compatible with the number of elements.
|
||||||
"""
|
"""
|
||||||
from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
|
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||||
|
|
||||||
dim = None
|
dim = None
|
||||||
newsize = 1
|
newsize = 1
|
||||||
@ -952,7 +952,7 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
|
|||||||
lambda: (
|
lambda: (
|
||||||
f"cannot reshape tensor of 0 elements into shape {list(shape)} because the "
|
f"cannot reshape tensor of 0 elements into shape {list(shape)} because the "
|
||||||
f"unspecified dimension size -1 can be any value and is ambiguous"
|
f"unspecified dimension size -1 can be any value and is ambiguous"
|
||||||
if definitely_true(numel == 0)
|
if guard_or_false(numel == 0)
|
||||||
else f"shape '{list(shape)}' is invalid for input of size {numel}"
|
else f"shape '{list(shape)}' is invalid for input of size {numel}"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -94,7 +94,7 @@ at::Tensor InputMetadata::maybe_reduce(
|
|||||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
|
if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
|
||||||
// NB: we could short circuit this once needs_reduce is true but there's
|
// NB: we could short circuit this once needs_reduce is true but there's
|
||||||
// no point since the reduction function will guard on this anyway
|
// no point since the reduction function will guard on this anyway
|
||||||
if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) {
|
if (!c10::guard_or_false(size.sym_eq(target), __FILE__, __LINE__)) {
|
||||||
needs_reduce = true;
|
needs_reduce = true;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -1230,10 +1230,22 @@ def _log_suppressed_dde(a: SymBool, assumed_value: bool) -> None:
|
|||||||
# of various framework code. Those would be used in situations you prefer to guard and know
|
# of various framework code. Those would be used in situations you prefer to guard and know
|
||||||
# the result of the expression over not guarding, but in case you hit a data dependent error
|
# the result of the expression over not guarding, but in case you hit a data dependent error
|
||||||
# you are ok with just returning true or false.
|
# you are ok with just returning true or false.
|
||||||
# Some reasons you might be ok with returning true/false instead could be:
|
#
|
||||||
# (1) It's an optimization/additional check I do not want to fail for not performing it.
|
# When to use this?
|
||||||
# (2) I am willing to deviate from the normal semantics when I have unbacked for the
|
# (1) If you can use a higher level combinator prefer using those instead, they are definitely safe (modulo short-circuiting).
|
||||||
# benefit of not failing.
|
#
|
||||||
|
# (2) It can be used if the program would behave equivalently if _guard_or returned true or false.
|
||||||
|
# Many inductor optimizations fall in this bracket for example.
|
||||||
|
#
|
||||||
|
# (3) Finally, it's even be OK if the program wouldn't behave equivalently, so long as the
|
||||||
|
# change is semantics preserving. It can be semantics preserving if the program errors in more
|
||||||
|
# cases than it did previously (but otherwise behaves identically), or if it changes some quantity
|
||||||
|
# in a way that doesn't matter (e.g., strides often fall in this bucket.)
|
||||||
|
#
|
||||||
|
# (4) Specialize for the general case and add a runtime assertion that would fail during
|
||||||
|
# runtime if the conditions for the general case are not satisfied. Examples for this are;
|
||||||
|
# assuming expand/reshape inputs are not -1. or assuming the non-broadcasting path.
|
||||||
|
#
|
||||||
def _guard_or(a: BoolLikeType, default: bool) -> bool:
|
def _guard_or(a: BoolLikeType, default: bool) -> bool:
|
||||||
if not isinstance(a, SymBool):
|
if not isinstance(a, SymBool):
|
||||||
assert isinstance(a, bool)
|
assert isinstance(a, bool)
|
||||||
@ -1272,48 +1284,6 @@ def guard_or_true(a: BoolLikeType) -> bool:
|
|||||||
return _guard_or(a, True)
|
return _guard_or(a, True)
|
||||||
|
|
||||||
|
|
||||||
def definitely_true(a: BoolLikeType) -> bool:
|
|
||||||
"""
|
|
||||||
Returns True only if we can tell that a is True, possibly introducing
|
|
||||||
a guard in the process. If a depends on some unbacked SymInt, we may
|
|
||||||
return False even though there may exist a possible value of the SymInt
|
|
||||||
that would cause the expression to return True.
|
|
||||||
|
|
||||||
When is it appropriate to use definitely_true? First, if you can use
|
|
||||||
a higher level combinator prefer using those instead, they are definitely
|
|
||||||
safe (modulo short-circuiting).
|
|
||||||
Second, it can be used if the program would behave equivalently if
|
|
||||||
definitely_true always returned False. Finally, it even
|
|
||||||
be OK if the program wouldn't behave equivalently, so long as the
|
|
||||||
change is semantics preserving. It can be semantics preserving if
|
|
||||||
the program errors in more cases than it did previously (but otherwise
|
|
||||||
behaves identically), or if it changes some quantity in a way that
|
|
||||||
doesn't matter (e.g., strides often fall in this bucket.)
|
|
||||||
"""
|
|
||||||
if isinstance(a, SymBool):
|
|
||||||
if a.node.has_hint():
|
|
||||||
return guard_bool(a)
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return bool(a)
|
|
||||||
|
|
||||||
|
|
||||||
def definitely_false(a: BoolLikeType) -> bool:
|
|
||||||
"""
|
|
||||||
Returns True only if we can tell that a is False, possibly introducing
|
|
||||||
a guard in the process. If a depends on some unbacked SymInt, we may
|
|
||||||
return False even though there may exist a possible value of the SymInt
|
|
||||||
that would cause the expression a to be False. See definitely_true
|
|
||||||
for more usage guidance.
|
|
||||||
"""
|
|
||||||
if isinstance(a, SymBool):
|
|
||||||
if a.node.has_hint():
|
|
||||||
return not guard_bool(a)
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return not bool(a)
|
|
||||||
|
|
||||||
|
|
||||||
def _static_eval_sym_bool(x: SymBool) -> Optional[bool]:
|
def _static_eval_sym_bool(x: SymBool) -> Optional[bool]:
|
||||||
assert isinstance(x, SymBool)
|
assert isinstance(x, SymBool)
|
||||||
expr = x.node.expr
|
expr = x.node.expr
|
||||||
|
Reference in New Issue
Block a user