from dataclasses import dataclass

import torch
import torch.utils._pytree as pytree

from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._dynamo.exc import CondOpArgsMismatchError
from torch._functorch.eager_transforms import (
    _unwrap_all_tensors_from_functional,
    _wrap_all_tensors_to_functional,
    functionalize,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _pop_mode_temporarily,
)
from torch.utils._pytree import tree_flatten


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str
"""
|
|
We're going to define a `cond` operation.
|
|
In order to do this, we need implementations for each of the dispatch keys.
|
|
"""
|
|
cond = HigherOrderOperator("cond")
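
# A minimal usage sketch (illustrative only, not part of this module's API
# surface): both branches take the same operands and must return tensors with
# matching metadata, and `operands` is passed as a list/tuple of tensors.
#
#   def true_fn(x):
#       return x.sin()
#
#   def false_fn(x):
#       return x.cos()
#
#   x = torch.randn(4)
#   out = cond(torch.tensor(True), true_fn, false_fn, [x])  # behaves like x.sin()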


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    assert isinstance(
        operands, (list, tuple)
    ), "Cond operands must be a list or tuple of tensors"
    assert all(
        isinstance(o, torch.Tensor) for o in operands
    ), "Cond operands must be a list of tensors"

    with disable_proxy_modes_tracing():
        true_graph = make_fx(true_fn)(*operands)
        false_graph = make_fx(false_fn)(*operands)

    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == "output":
            true_outs.extend(node.args)

    for node in false_graph.graph.nodes:
        if node.op == "output":
            false_outs.extend(node.args)

    flat_true_outs, _ = pytree.tree_flatten(true_outs)
    flat_false_outs, _ = pytree.tree_flatten(false_outs)
    if len(flat_true_outs) != len(flat_false_outs):
        raise CondOpArgsMismatchError(
            f"Expected both branches to return the same number of outputs, but got:"
            f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
            f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
        )

    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
            raise CondOpArgsMismatchError(
                f"Expected each tensor to have the same metadata, but got:"
                f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
                f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
            )

    # There are probably better ways - I know that create_arg has some
    # self-incrementing name magic to it, but since we explicitly have to get
    # the name for register_module, I was not sure how to do that. This kinda
    # simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"true_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    true_name = next_name
    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, operands)

    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="conditional"
    )

    # At this point, we're *guaranteed* that whether an output came from the
    # true or false branch is indistinguishable. So, as this is just for tracing
    # purposes, choose the true branch.

    # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in
    # a FakeTensorMode error:
    # `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
    out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@cond.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_dense(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)


@cond.py_impl(DispatchKey.Autograd)
def cond_autograd(pred, true_fn, false_fn, *operands):
    # TODO: support autograd
    flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])

    requires_grad = any(
        isinstance(arg, torch.Tensor) and arg.requires_grad for arg in flat_operands
    )

    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)):
        result = cond(pred, true_fn, false_fn, *operands)

        # If any input requires grad, delay the error until the backward pass.
        if requires_grad:
            # cond can only return one value
            err_fn = torch._C._functions.DelayedError(
                b"NYI: torch.cond doesn't support autograd",
                1,
            )
            # Create aliases of the output that have requires_grad=True. We need
            # at least one of the inputs to err_fn to require grad so that the
            # output will have a grad_fn.

            def fake_requires_grad(var):
                if var is not None:
                    var = var.detach()
                    var.requires_grad = True
                return var

            return err_fn(fake_requires_grad(result))

        return result


@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is not None, "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as mode:
        if mode.enable_tracing:
            return trace_cond(mode, cond, pred, true_fn, false_fn, operands)
        else:
            return cond(pred, true_fn, false_fn, operands)


@cond.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
    true_outs = true_fn(*operands)
    flat_true_outs, _ = pytree.tree_flatten(true_outs)
    flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
    if len(flat_true_outs) != len(flat_false_outs):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")

    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        true_meta = _extract_tensor_metadata(true_out)
        false_meta = _extract_tensor_metadata(false_out)
        if true_meta != false_meta:
            raise RuntimeError(
                f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}"
            )
    return true_outs


def _has_potential_branch_input_mutation(branch, inputs):
    """
    Dispatch-trace the branch with inputs and check whether the resulting
    graph contains a mutable op applied to an input. This is a bit
    restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond is functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_mutation(gm):
        input_nodes = set()
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                input_nodes.add(node)
            if node.op == "call_function":
                target = node.target
                if (
                    isinstance(target, torch._ops.OpOverload)
                    and target._schema.is_mutable
                ):
                    for arg in node.args:
                        if arg in input_nodes:
                            return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule):
                if _detect_input_mutation(module):
                    return True

        return False

    return _detect_input_mutation(gm)


def _has_potential_branch_input_alias(branch, inputs):
    """
    Dispatch-trace the branch with inputs and check whether the resulting
    graph has an output that aliases a branch input. This is a bit
    restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond is functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_alias(gm):
        input_storages = set()
        for node in gm.graph.nodes:
            # We need to check existence of "val" because we reuse the logic here
            # for the map operator, where num_mapped_args is a scalar
            # and doesn't have a "val" meta.
            if node.op == "placeholder" and "val" in node.meta:
                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
            if node.op == "output":

                def check_alias(out):
                    if out is not None and "val" in out.meta:
                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                        return out_storage in input_storages
                    return False

                if any(pytree.tree_flatten(pytree.tree_map(check_alias, node.args))[0]):
                    return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
                return True

        return False

    return _detect_input_alias(gm)


@cond.py_impl(DispatchKey.Functionalize)
def cond_func(pred, true_fn, false_fn, inputs):
    reapply_views = torch._C._functionalization_reapply_views_tls()
    unwrapped_inputs = _unwrap_all_tensors_from_functional(
        inputs, reapply_views=reapply_views
    )
    unwrapped_pred = _unwrap_all_tensors_from_functional(
        pred, reapply_views=reapply_views
    )
    mode = "mutations_and_views" if reapply_views else "mutations"
    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
        functional_true = functionalize(true_fn, remove=mode)
        functional_false = functionalize(false_fn, remove=mode)
        for branch in [true_fn, false_fn]:
            if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be modifying the input!"
                )

            if _has_potential_branch_input_alias(branch, unwrapped_inputs):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be aliasing the input!"
                )

        cond_return = cond(
            unwrapped_pred, functional_true, functional_false, unwrapped_inputs
        )
        return _wrap_all_tensors_to_functional(cond_return, level=0)


@cond.py_impl(torch._C._functorch.TransformType.Functionalize)
def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
    """
    Functionalization implementation for torch.cond. Currently:
    1. We don't allow any input mutation inside the branches.
    2. Our check for the above condition is not exhaustive.
    """
    reapply_views = interpreter.functionalize_add_back_views()
    mode = "mutations_and_views" if reapply_views else "mutations"
    # At this point, we will see functionalized tensors, so need to unwrap them first
    unwrapped_inputs = _unwrap_all_tensors_from_functional(
        inputs, reapply_views=reapply_views
    )
    unwrapped_pred = _unwrap_all_tensors_from_functional(
        pred, reapply_views=reapply_views
    )

    functional_true_fn = functionalize(true_fn, remove=mode)
    functional_false_fn = functionalize(false_fn, remove=mode)

    with interpreter.lower():
        for branch in [functional_true_fn, functional_false_fn]:
            if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be modifying the input!"
                )
        for branch in [true_fn, false_fn]:
            if _has_potential_branch_input_alias(branch, unwrapped_inputs):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be aliasing the input!"
                )

        cond_return = cond(
            unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs
        )
        return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())


# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonDispatcher)
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
cond.fallthrough(DispatchKey.BackendSelect)
cond.fallthrough(DispatchKey.AutocastCPU)