mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[HigherOrderOp] expose torch.cond (#110293)"
This reverts commit 601f872831649bccf1069ac59b2ecfd0895a88e3. Reverted https://github.com/pytorch/pytorch/pull/110293 on behalf of https://github.com/ydwu4 due to Sorry, didn't check the error carefully on the PR. A doc error is related to this pr ([comment](https://github.com/pytorch/pytorch/pull/110293#issuecomment-1751176719))
This commit is contained in:
@ -501,7 +501,7 @@ Graph breaks can also be encountered on data-dependent control flow (``if
|
||||
x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot
|
||||
possibly deal with without generating code for a combinatorially exploding
|
||||
number of paths. In such cases, users will need to rewrite their code using
|
||||
special control flow operators. Currently, we support :ref:`torch.cond <cond>`
|
||||
special control flow operators. Currently, we support :ref:`torch.cond <control_flow_cond>`
|
||||
to express if-else like control flow (more coming soon!).
|
||||
|
||||
Data-Dependent Accesses
|
||||
@ -540,7 +540,7 @@ Read More
|
||||
torch.compiler_transformations
|
||||
torch.compiler_ir
|
||||
generated/exportdb/index
|
||||
cond
|
||||
control_flow_cond
|
||||
|
||||
.. toctree::
|
||||
:caption: Deep Dive for PyTorch Developers
|
||||
|
@ -718,17 +718,6 @@ Export Path
|
||||
export
|
||||
generated/exportdb/index
|
||||
|
||||
Control Flow
|
||||
------------
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
.. warning::
|
||||
This feature is a prototype and may have compatibility breaking changes in the future.
|
||||
|
||||
cond
|
||||
|
||||
Optimizations
|
||||
-------------
|
||||
.. autosummary::
|
||||
|
@ -1,4 +1,6 @@
|
||||
from torch import cond # noqa: F401
|
||||
from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
|
||||
from torch._higher_order_ops.cond import ( # noqa: F401
|
||||
cond,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
|
||||
from ._map import map # noqa: F401
|
||||
|
@ -56,7 +56,7 @@ __all__ = [
|
||||
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
|
||||
'SymBool', 'sym_not',
|
||||
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
|
||||
'export', 'autocast', 'cond',
|
||||
'export', 'autocast',
|
||||
]
|
||||
|
||||
################################################################################
|
||||
@ -986,7 +986,7 @@ def is_warn_always_enabled() -> builtins.bool:
|
||||
# These error checking functions must be kept consistent with their C++
|
||||
# equivalents. Their C++ equivalents are mentioned where applicable.
|
||||
|
||||
def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]): # noqa: F811
|
||||
def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
|
||||
if not isinstance(cond, (builtins.bool, torch.SymBool)):
|
||||
raise TypeError(f'cond must be a bool, but got {type(cond)}')
|
||||
|
||||
@ -1010,7 +1010,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab
|
||||
|
||||
raise error_type(message_evaluated)
|
||||
|
||||
def _check(cond, message=None): # noqa: F811
|
||||
def _check(cond, message=None):
|
||||
r"""Throws error containing an optional message if the specified condition
|
||||
is False.
|
||||
|
||||
@ -1041,7 +1041,7 @@ def _check_is_size(i, message=None):
|
||||
_check(i >= 0, message)
|
||||
torch.fx.experimental.symbolic_shapes._advise_is_size(i)
|
||||
|
||||
def _check_index(cond, message=None): # noqa: F811
|
||||
def _check_index(cond, message=None):
|
||||
r"""Throws error containing an optional message if the specified condition
|
||||
is False.
|
||||
|
||||
@ -1058,7 +1058,7 @@ def _check_index(cond, message=None): # noqa: F811
|
||||
"""
|
||||
_check_with(IndexError, cond, message)
|
||||
|
||||
def _check_value(cond, message=None): # noqa: F811
|
||||
def _check_value(cond, message=None):
|
||||
r"""Throws error containing an optional message if the specified condition
|
||||
is False.
|
||||
|
||||
@ -1075,7 +1075,7 @@ def _check_value(cond, message=None): # noqa: F811
|
||||
"""
|
||||
_check_with(ValueError, cond, message)
|
||||
|
||||
def _check_type(cond, message=None): # noqa: F811
|
||||
def _check_type(cond, message=None):
|
||||
r"""Throws error containing an optional message if the specified condition
|
||||
is False.
|
||||
|
||||
@ -1092,7 +1092,7 @@ def _check_type(cond, message=None): # noqa: F811
|
||||
"""
|
||||
_check_with(TypeError, cond, message)
|
||||
|
||||
def _check_not_implemented(cond, message=None): # noqa: F811
|
||||
def _check_not_implemented(cond, message=None):
|
||||
r"""Throws error containing an optional message if the specified condition
|
||||
is False.
|
||||
|
||||
@ -1109,7 +1109,7 @@ def _check_not_implemented(cond, message=None): # noqa: F811
|
||||
"""
|
||||
_check_with(NotImplementedError, cond, message)
|
||||
|
||||
def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
|
||||
def _check_tensor_all_with(error_type, cond, message=None):
|
||||
if not torch.is_tensor(cond):
|
||||
raise TypeError(f'cond must be a tensor, but got {type(cond)}')
|
||||
|
||||
@ -1120,7 +1120,7 @@ def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
|
||||
_check_with(error_type, cond._is_all_true().item(), message)
|
||||
|
||||
# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
|
||||
def _check_tensor_all(cond, message=None): # noqa: F811
|
||||
def _check_tensor_all(cond, message=None):
|
||||
r"""Throws error containing an optional message if the specified condition
|
||||
is False.
|
||||
|
||||
@ -1761,7 +1761,6 @@ def compile(model: Optional[Callable] = None, *,
|
||||
|
||||
from torch import export as export
|
||||
|
||||
from torch._higher_order_ops import cond
|
||||
|
||||
def _register_device_module(device_type, module):
|
||||
r"""Register an external runtime module of the specific :attr:`device_type`
|
||||
|
@ -215,7 +215,6 @@ def _allowed_function_ids():
|
||||
torch.func.vmap,
|
||||
deprecated_func.vmap,
|
||||
torch.nn.functional.triplet_margin_with_distance_loss,
|
||||
torch.cond,
|
||||
):
|
||||
continue
|
||||
|
||||
|
@ -1 +0,0 @@
|
||||
from .cond import cond
|
||||
|
@ -8,7 +8,8 @@ import torch.fx.traceback as fx_traceback
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from torch._C import DispatchKey
|
||||
from torch._functorch.utils import exposed_in
|
||||
from torch._dynamo.exc import CondOpArgsMismatchError
|
||||
from torch._dynamo.utils import disable_cache_limit
|
||||
|
||||
from torch._higher_order_ops.utils import autograd_not_implemented
|
||||
from torch._ops import HigherOrderOperator
|
||||
@ -41,7 +42,6 @@ class UnsupportedAliasMutationException(RuntimeError):
|
||||
reason: str
|
||||
|
||||
|
||||
@exposed_in("torch")
|
||||
def cond(pred, true_fn, false_fn, operands):
|
||||
r"""
|
||||
Conditionally applies `true_fn` or `false_fn`.
|
||||
@ -142,7 +142,7 @@ def cond(pred, true_fn, false_fn, operands):
|
||||
raise RuntimeError("torch.cond requires dynamo support.")
|
||||
|
||||
with _set_compilation_env():
|
||||
with torch._dynamo.utils.disable_cache_limit():
|
||||
with disable_cache_limit():
|
||||
return torch.compile(cond_op, backend="eager", fullgraph=True)(
|
||||
pred, true_fn, false_fn, operands
|
||||
)
|
||||
@ -198,7 +198,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
flat_true_outs, _ = pytree.tree_flatten(true_outs)
|
||||
flat_false_outs, _ = pytree.tree_flatten(false_outs)
|
||||
if len(flat_true_outs) != len(flat_false_outs):
|
||||
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||
raise CondOpArgsMismatchError(
|
||||
f"Expected to return same number of outputs but got:"
|
||||
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
|
||||
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
|
||||
@ -208,7 +208,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
true_out = flat_true_outs[i]
|
||||
false_out = flat_false_outs[i]
|
||||
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
|
||||
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||
raise CondOpArgsMismatchError(
|
||||
f"Expected each tensor to have same metadata but got:"
|
||||
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
|
||||
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
|
||||
@ -291,7 +291,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
|
||||
true_meta = _extract_tensor_metadata(true_out)
|
||||
false_meta = _extract_tensor_metadata(false_out)
|
||||
if true_meta != false_meta:
|
||||
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||
raise CondOpArgsMismatchError(
|
||||
f"Expected each tensor to have same metadata but got:"
|
||||
f"\n {true_fn.__name__} returns {true_meta}"
|
||||
f"\n {false_fn.__name__} returns {false_meta}"
|
||||
|
@ -297,7 +297,6 @@ def get_ignored_functions() -> Set[Callable]:
|
||||
torch.set_vital,
|
||||
torch.read_vitals,
|
||||
torch.vmap,
|
||||
torch.cond,
|
||||
torch.frombuffer,
|
||||
torch.asarray,
|
||||
torch._functional_sym_constrain_range,
|
||||
|
Reference in New Issue
Block a user