Revert "[HigherOrderOp] expose torch.cond (#110293)"

This reverts commit 601f872831649bccf1069ac59b2ecfd0895a88e3.

Reverted https://github.com/pytorch/pytorch/pull/110293 on behalf of https://github.com/ydwu4 due to Sorry, didn't check the error carefully on the PR. A doc error is related to this pr ([comment](https://github.com/pytorch/pytorch/pull/110293#issuecomment-1751176719))
PyTorch MergeBot
2023-10-06 17:44:17 +00:00
parent e75f2e2ea1
commit 576b80d23e
9 changed files with 21 additions and 34 deletions

View File

@@ -501,7 +501,7 @@ Graph breaks can also be encountered on data-dependent control flow (``if
x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot
possibly deal with without generating code for a combinatorially exploding
number of paths. In such cases, users will need to rewrite their code using
-special control flow operators. Currently, we support :ref:`torch.cond <cond>`
+special control flow operators. Currently, we support :ref:`torch.cond <control_flow_cond>`
to express if-else like control flow (more coming soon!).
Data-Dependent Accesses
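
Aside: a minimal usage sketch of the control-flow operator this passage points users to, based on the cond(pred, true_fn, false_fn, operands) signature shown later in this diff. With this revert applied the operator is no longer exposed as torch.cond, so the sketch imports it from functorch.experimental.control_flow; the tensor and branch-function names are illustrative only.

    import torch
    from functorch.experimental.control_flow import cond

    def true_fn(x):
        return x.sin()

    def false_fn(x):
        return x.cos()

    x = torch.randn(4)
    # pred may be a Python bool or a single-element boolean tensor; this mirrors
    # the doc's ``x.shape[0] > 2`` example, and operands is a tuple of tensors.
    out = cond(x.shape[0] > 2, true_fn, false_fn, (x,))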
@@ -540,7 +540,7 @@ Read More
torch.compiler_transformations
torch.compiler_ir
generated/exportdb/index
-cond
+control_flow_cond
.. toctree::
:caption: Deep Dive for PyTorch Developers

View File

@@ -718,17 +718,6 @@ Export Path
export
generated/exportdb/index
-Control Flow
-------------
-.. autosummary::
-:toctree: generated
-:nosignatures:
-.. warning::
-This feature is a prototype and may have compatibility breaking changes in the future.
-cond
Optimizations
-------------
.. autosummary::

View File

@@ -1,4 +1,6 @@
-from torch import cond # noqa: F401
-from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
+from torch._higher_order_ops.cond import ( # noqa: F401
+cond,
+UnsupportedAliasMutationException,
+)
from ._map import map # noqa: F401

View File

@@ -56,7 +56,7 @@ __all__ = [
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
'SymBool', 'sym_not',
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
-'export', 'autocast', 'cond',
+'export', 'autocast',
]
################################################################################
@@ -986,7 +986,7 @@ def is_warn_always_enabled() -> builtins.bool:
# These error checking functions must be kept consistent with their C++
# equivalents. Their C++ equivalents are mentioned where applicable.
-def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]): # noqa: F811
+def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
if not isinstance(cond, (builtins.bool, torch.SymBool)):
raise TypeError(f'cond must be a bool, but got {type(cond)}')
@@ -1010,7 +1010,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab
raise error_type(message_evaluated)
-def _check(cond, message=None): # noqa: F811
+def _check(cond, message=None):
r"""Throws error containing an optional message if the specified condition
is False.
@@ -1041,7 +1041,7 @@ def _check_is_size(i, message=None):
_check(i >= 0, message)
torch.fx.experimental.symbolic_shapes._advise_is_size(i)
-def _check_index(cond, message=None): # noqa: F811
+def _check_index(cond, message=None):
r"""Throws error containing an optional message if the specified condition
is False.
@@ -1058,7 +1058,7 @@ def _check_index(cond, message=None): # noqa: F811
"""
_check_with(IndexError, cond, message)
-def _check_value(cond, message=None): # noqa: F811
+def _check_value(cond, message=None):
r"""Throws error containing an optional message if the specified condition
is False.
@@ -1075,7 +1075,7 @@ def _check_value(cond, message=None): # noqa: F811
"""
_check_with(ValueError, cond, message)
-def _check_type(cond, message=None): # noqa: F811
+def _check_type(cond, message=None):
r"""Throws error containing an optional message if the specified condition
is False.
@@ -1092,7 +1092,7 @@ def _check_type(cond, message=None): # noqa: F811
"""
_check_with(TypeError, cond, message)
-def _check_not_implemented(cond, message=None): # noqa: F811
+def _check_not_implemented(cond, message=None):
r"""Throws error containing an optional message if the specified condition
is False.
@@ -1109,7 +1109,7 @@ def _check_not_implemented(cond, message=None): # noqa: F811
"""
_check_with(NotImplementedError, cond, message)
-def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
+def _check_tensor_all_with(error_type, cond, message=None):
if not torch.is_tensor(cond):
raise TypeError(f'cond must be a tensor, but got {type(cond)}')
@@ -1120,7 +1120,7 @@ def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
_check_with(error_type, cond._is_all_true().item(), message)
# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
-def _check_tensor_all(cond, message=None): # noqa: F811
+def _check_tensor_all(cond, message=None):
r"""Throws error containing an optional message if the specified condition
is False.
@@ -1761,7 +1761,6 @@ def compile(model: Optional[Callable] = None, *,
from torch import export as export
-from torch._higher_order_ops import cond
def _register_device_module(device_type, module):
r"""Register an external runtime module of the specific :attr:`device_type`

View File

@@ -215,7 +215,6 @@ def _allowed_function_ids():
torch.func.vmap,
deprecated_func.vmap,
torch.nn.functional.triplet_margin_with_distance_loss,
-torch.cond,
):
continue

View File

@@ -1 +0,0 @@
-from .cond import cond

View File

@@ -8,7 +8,8 @@ import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._C import DispatchKey
-from torch._functorch.utils import exposed_in
+from torch._dynamo.exc import CondOpArgsMismatchError
+from torch._dynamo.utils import disable_cache_limit
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
@@ -41,7 +42,6 @@ class UnsupportedAliasMutationException(RuntimeError):
reason: str
-@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
r"""
Conditionally applies `true_fn` or `false_fn`.
@@ -142,7 +142,7 @@ def cond(pred, true_fn, false_fn, operands):
raise RuntimeError("torch.cond requires dynamo support.")
with _set_compilation_env():
-with torch._dynamo.utils.disable_cache_limit():
+with disable_cache_limit():
return torch.compile(cond_op, backend="eager", fullgraph=True)(
pred, true_fn, false_fn, operands
)
@@ -198,7 +198,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
flat_true_outs, _ = pytree.tree_flatten(true_outs)
flat_false_outs, _ = pytree.tree_flatten(false_outs)
if len(flat_true_outs) != len(flat_false_outs):
-raise torch._dynamo.exc.CondOpArgsMismatchError(
+raise CondOpArgsMismatchError(
f"Expected to return same number of outputs but got:"
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
@@ -208,7 +208,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
true_out = flat_true_outs[i]
false_out = flat_false_outs[i]
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
-raise torch._dynamo.exc.CondOpArgsMismatchError(
+raise CondOpArgsMismatchError(
f"Expected each tensor to have same metadata but got:"
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
@@ -291,7 +291,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
true_meta = _extract_tensor_metadata(true_out)
false_meta = _extract_tensor_metadata(false_out)
if true_meta != false_meta:
-raise torch._dynamo.exc.CondOpArgsMismatchError(
+raise CondOpArgsMismatchError(
f"Expected each tensor to have same metadata but got:"
f"\n {true_fn.__name__} returns {true_meta}"
f"\n {false_fn.__name__} returns {false_meta}"

View File

@@ -297,7 +297,6 @@ def get_ignored_functions() -> Set[Callable]:
torch.set_vital,
torch.read_vitals,
torch.vmap,
-torch.cond,
torch.frombuffer,
torch.asarray,
torch._functional_sym_constrain_range,