This is an updated PR that equips `cond` with autograd support and replaces the old [PR](https://github.com/pytorch/pytorch/pull/126007).
@ydwu4 I have already tried to incorporate your requests.
Currently, there are two problems that I am struggling to solve:
1. There seems to be an import issue when trying to import `cond` in `torch/__init__.py`, see [here](8a704035c9/torch/__init__.py (L1914-L1916)). Commenting out those lines resolved the import issue, but I believe `cond` is then not properly exposed as `torch.cond`.
2. I am not entirely sure how to deal with the opinfo test in `hop_db.py`.

Co-authored-by: Yidi Wu <yidi@meta.com>
Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126911
Approved by: https://github.com/ydwu4
388 lines
14 KiB
Python
# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import OperatorBase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad this will either
    raise an error or return a DelayedError depending on the value of delayed_error.

    Args:
        operator: The Operator to call with the given *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled, any of the arguments to the Operator
            require grad, and delayed_error is False
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner
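

# A minimal usage sketch (not part of the original file): `autograd_not_implemented`
# is typically registered as the Autograd-key fallback of a higher-order op that has
# no backward formula. With deferred_error=True the forward still runs, and a
# RuntimeError is raised only if someone later calls .backward() through the result.
# The operator below (aten.sin) is just a stand-in to keep the sketch self-contained.
def _example_autograd_not_implemented():
    x = torch.randn(3, requires_grad=True)
    fallback = autograd_not_implemented(torch.ops.aten.sin.default, deferred_error=True)
    y = fallback(x)  # forward is computed below the Autograd key
    # y.sum().backward()  # would raise an "Autograd not implemented" error via DelayedError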


def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running graph with interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn
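

# A minimal usage sketch (not part of the original file): when node-meta preservation
# is active and `fn` is a GraphModule, the helper returns a wrapper that replays the
# graph through torch.fx.Interpreter so stack traces propagate; any other callable is
# returned unchanged.
def _example_maybe_run_with_interpreter():
    gm = torch.fx.symbolic_trace(lambda x: x + 1)
    with fx_traceback.preserve_node_meta():
        wrapped = _maybe_run_with_interpreter(gm)
        wrapped(torch.randn(2))  # runs gm via torch.fx.Interpreter
    assert _maybe_run_with_interpreter(torch.sin) is torch.sin  # non-GraphModule passthrough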


def reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped


def _maybe_reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:
        return make_fx(fn)
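

# A minimal usage sketch (not part of the original file): outside of a make_fx tracing
# session, _maybe_reenter_make_fx falls back to a plain make_fx call; inside a session
# it traces `fn` as a subgraph of the ambient tracer (this is how cond traces its
# branches while the outer graph is being built).
def _example_maybe_reenter_make_fx():
    def body(x):
        return x.cos()

    gm = _maybe_reenter_make_fx(body)(torch.randn(3))  # no ambient tracer -> make_fx(body)
    print(gm.graph)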


@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    _old_is_inlining = torch._dynamo.config.inline_inbuilt_nn_modules
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False

        # TODO(anijain2305, export-team) For non-strict export with module
        # stack info, the codepath forces the nn module __getattr__ to
        # ProxyAttr __getattr__ downstream. To circumvent the issue for now,
        # skip inlining inbuilt nn modules for cond.
        torch._dynamo.config.inline_inbuilt_nn_modules = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
        torch._dynamo.config.inline_inbuilt_nn_modules = _old_is_inlining
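

# A minimal usage sketch (not part of the original file): higher-order ops such as
# cond enter this context before calling into torch.compile, so that the ambient fx
# tracing flag and dynamo's inbuilt-nn-module inlining are temporarily disabled and
# restored afterwards.
def _example_set_compilation_env():
    with _set_compilation_env():
        assert torch.fx._symbolic_trace._is_fx_tracing_flag is False
        assert torch._dynamo.config.inline_inbuilt_nn_modules is False
    # the previous values are restored on exit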


def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if
    the resulting graph has a mutable op on the input. This is
    a bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_mutation(gm):
        input_nodes = set()
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                input_nodes.add(node)
            if node.op == "call_function":
                target = node.target
                if (
                    isinstance(target, torch._ops.OpOverload)
                    and target._schema.is_mutable
                ):
                    for arg in node.args:
                        if arg in input_nodes:
                            return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule):
                if _detect_input_mutation(module):
                    return True

        return False

    return _detect_input_mutation(gm)
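

# A minimal usage sketch (not part of the original file): a branch that writes into
# its input in place is flagged, while a purely functional branch is not.
def _example_has_potential_branch_input_mutation():
    def mutating_branch(x):
        x.add_(1)
        return x.clone()

    def pure_branch(x):
        return x + 1

    inputs = (torch.randn(3),)
    assert _has_potential_branch_input_mutation(mutating_branch, inputs)
    assert not _has_potential_branch_input_mutation(pure_branch, inputs)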


def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if
    the resulting graph has an output aliasing the branch input. This is
    a bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_alias(gm):
        input_storages = set()
        for node in gm.graph.nodes:
            # We need to check existence of "val" because we reuse the logic here
            # for map operator, where num_mapped_args is a scalar
            # and doesn't have a "val" meta.
            if node.op == "placeholder" and "val" in node.meta:
                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
            if node.op == "output":

                def check_alias(out):
                    if out is not None and "val" in out.meta:
                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                        return out_storage in input_storages
                    return False

                if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                    return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
                return True

        return False

    return _detect_input_alias(gm)
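

# A minimal usage sketch (not part of the original file): a branch whose output is a
# view of its input shares storage with the input and is therefore flagged; cloning
# breaks the alias.
def _example_has_potential_branch_input_alias():
    def aliasing_branch(x):
        return x.view(-1)

    def cloning_branch(x):
        return x.view(-1).clone()

    inputs = (torch.randn(2, 2),)
    assert _has_potential_branch_input_alias(aliasing_branch, inputs)
    assert not _has_potential_branch_input_alias(cloning_branch, inputs)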


def unique_graph_id(proxy_mode, prefix):
    """Returns a unique name and id for a graph to be added to a proxy_mode tracer"""
    # There are probably better ways - I know that create_arg has some self incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{prefix}_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate
    return i, next_name


def _from_fun(t):
    from torch._functorch.aot_autograd import from_fun
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
            )
        else:
            # clone of a functional tensor produces a functional tensor
            # but we want to avoid it so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t


def clone_outputs_aliasing_inputs(args):
    input_storage = {
        StorageWeakRef(arg._typed_storage())
        for arg in args
        if isinstance(arg, torch.Tensor)
    }

    def maybe_clone(t):
        if (
            isinstance(t, torch.Tensor)
            and StorageWeakRef(t._typed_storage()) in input_storage
        ):
            return t.clone()
        return t

    return maybe_clone
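

# A minimal usage sketch (not part of the original file): outputs that share storage
# with any of the given inputs get cloned, everything else passes through untouched.
def _example_clone_outputs_aliasing_inputs():
    x = torch.randn(3)
    maybe_clone = clone_outputs_aliasing_inputs((x,))
    view_of_x = x.view(3)   # shares storage with x
    fresh = torch.randn(3)  # separate storage
    assert maybe_clone(view_of_x) is not view_of_x  # aliasing output is cloned
    assert maybe_clone(fresh) is fresh              # non-aliasing output is returned as-is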


def prepare_fw_with_masks(fn):
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
            for ret in fw_out
        ]

    return fw_with_masks
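

# A minimal usage sketch (not part of the original file): the wrapper pairs the
# forward outputs with a boolean mask marking which outputs participate in autograd,
# which is how create_fw_bw_graph below feeds the forward to create_joint.
def _example_prepare_fw_with_masks():
    def fw(x, y):
        return (x * 2, y + 1)

    x = torch.randn(2, requires_grad=True)
    y = torch.randn(2)  # does not require grad
    fw_out, masks = prepare_fw_with_masks(fw)(x, y)
    assert masks == [True, False]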


# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note: [HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. Will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
    # fetch the proxy for the inputs and fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
    # only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create _tensor_constant as output.

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # In order to keep map functional for backward graph,
        # we clone outputs that are aliasing inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph
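

# A minimal sketch (not part of the original file) of how a HOP autograd key might
# drive this helper: example forward outputs are produced once, then the forward
# graph and the joint (backward) graph are traced from them. This is only an
# illustration; the real callers (e.g. cond, map, while_loop) also run the branches
# through _from_fun and the alias/mutation checks first.
def _example_create_fw_bw_graph():
    def body(x):
        return (x.sin(),)

    fw_inputs = (torch.randn(3, requires_grad=True),)
    fw_outputs = body(*fw_inputs)
    fw_graph, joint_graph = create_fw_bw_graph(body, False, fw_inputs, fw_outputs)
    # fw_graph re-traces the forward; joint_graph maps (grads, inputs) -> input grads
    print(fw_graph.graph)
    print(joint_graph.graph)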


def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
        )

    a = zip(*flat_xs)

    pytrees = []
    for tuple in a:
        pytrees.append(pytree.tree_unflatten(tuple, inspec))
    return pytrees


def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
    assert out_spec is not None
    b = zip(*flat_out)
    stacked_out = []
    for leaves in b:
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # Backward graph can return None output when forward inputs don't require grad.
            # When we eagerly execute backward graph, we need to call _stack_pytree on its output,
            # therefore we need to deal with None output.
            stacked_out.append(None)  # type: ignore[arg-type]
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)
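

# A minimal round-trip sketch (not part of the original file): _unstack_pytree splits
# a pytree of batched tensors along dim 0 into per-slice pytrees (as map does for each
# mapped element), and _stack_pytree stacks per-slice results back into batched form.
def _example_unstack_stack_pytree():
    batched = {"a": torch.randn(4, 3), "b": torch.randn(4, 2)}
    slices = _unstack_pytree(batched)  # list of 4 dicts with tensors of shape (3,) / (2,)
    restacked = _stack_pytree(slices)  # back to tensors of shape (4, 3) / (4, 2)
    assert torch.equal(restacked["a"], batched["a"])
    assert torch.equal(restacked["b"], batched["b"])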