Files
pytorch/torch/_higher_order_ops/utils.py
Thomas Ortner 8ae1963a61 [Autograd] Cond Higher-Order Operation (#126911)
This is an updated PR that equips cond with the autograd feature; it replaces the old [PR](https://github.com/pytorch/pytorch/pull/126007)

@ydwu4 I tried to incorporate your requests already.

Currently there are two problems that I am still struggling to solve:

1. There seems to be an import issue when trying to import cond in `torch/__init__.py`, see [here](8a704035c9/torch/__init__.py (L1914-L1916)). Therefore, I had to comment out those lines, which resolved the import issues, but I believe cond is not properly exposed as torch.cond.
2. I am not entirely sure how to deal with the opinfo test in `hop_db.py`.

Co-authored-by: Yidi Wu <yidi@meta.com>
Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126911
Approved by: https://github.com/ydwu4
2024-07-22 23:18:19 +00:00


# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import OperatorBase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str
def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad, this will either
    raise an error or return a DelayedError depending on the value of delayed_error.

    Args:
        operator: The Operator to call with *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled, any of the arguments to the Operator
            require grad, and delayed_error is False
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")

        return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner
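
# Illustrative sketch (hypothetical helper, not part of the original module): ops that
# have no backward formula typically register the callable returned above as their
# Autograd kernel. With deferred_error=True nothing is raised eagerly; the error only
# surfaces if the result later participates in a backward pass.
def _example_register_autograd_not_implemented(op: OperatorBase) -> None:
    # `op` is assumed to be a HigherOrderOperator (or other OperatorBase) that
    # supports py_impl-style kernel registration.
    op.py_impl(torch._C.DispatchKey.Autograd)(
        autograd_not_implemented(op, deferred_error=True)
    )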
def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running graph with interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn


def reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped


def _maybe_reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:
        return make_fx(fn)
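
# Illustrative sketch (hypothetical helper, not part of the original module): the shape
# of how a higher-order op's tracing handler turns a branch/body callable into its own
# GraphModule, reusing the enclosing make_fx session when one is active and falling
# back to a fresh make_fx call otherwise.
def _example_trace_subgraph(body_fn, *operands):
    return _maybe_reenter_make_fx(body_fn)(*operands)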
@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    _old_is_inlining = torch._dynamo.config.inline_inbuilt_nn_modules
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False

        # TODO(anijain2305, export-team) For non-strict export with module
        # stack info, the codepath forces the nn module __getattr__ to
        # ProxyAttr __getattr__ downstream. To circumvent the issue for now,
        # skip inlining inbuilt nn modules for cond.
        torch._dynamo.config.inline_inbuilt_nn_modules = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
        torch._dynamo.config.inline_inbuilt_nn_modules = _old_is_inlining
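
# Illustrative sketch (hypothetical wrapper, not part of the original module): the
# context manager above is meant to bracket the torch.compile() call that a
# higher-order op frontend such as torch.cond uses to enter dynamo in eager mode.
def _example_compile_under_clean_env(fn, *args):
    with _set_compilation_env():
        return torch.compile(fn, backend="eager", fullgraph=True)(*args)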
def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if
    the produced graph has a mutable op on the input. This is
    a bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_mutation(gm):
        input_nodes = set()
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                input_nodes.add(node)
            if node.op == "call_function":
                target = node.target
                if (
                    isinstance(target, torch._ops.OpOverload)
                    and target._schema.is_mutable
                ):
                    for arg in node.args:
                        if arg in input_nodes:
                            return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule):
                if _detect_input_mutation(module):
                    return True

        return False

    return _detect_input_mutation(gm)
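
# Illustrative sketch (hypothetical branch and operands, not from the original module):
# this mirrors how torch.cond and other higher-order ops reject branches that mutate
# their inputs.
def _example_reject_mutating_branch():
    def bad_branch(x):
        x.add_(1)  # in-place op on the traced input
        return x.cos()

    operands = (torch.randn(3),)
    if _has_potential_branch_input_mutation(bad_branch, operands):
        raise UnsupportedAliasMutationException("branch might be modifying the input!")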
def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if
    the produced graph has an output aliasing the branch input. This is
    a bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_alias(gm):
        input_storages = set()
        for node in gm.graph.nodes:
            # We need to check existence of "val" because we reuse the logic here
            # for map operator, where num_mapped_args is a scalar
            # and doesn't have a "val" meta.
            if node.op == "placeholder" and "val" in node.meta:
                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
            if node.op == "output":

                def check_alias(out):
                    if out is not None and "val" in out.meta:
                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                        return out_storage in input_storages
                    return False

                if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                    return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
                return True

        return False

    return _detect_input_alias(gm)
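
# Illustrative sketch (hypothetical helper, not part of the original module): `branch`
# and `operands` are assumed to be what a higher-order op like torch.cond would pass,
# i.e. a traceable callable plus its example inputs. Storage overlap between an
# output's "val" meta and any input's "val" meta is treated as potential aliasing.
def _example_reject_aliasing_branch(branch, operands, pre_dispatch=False):
    if _has_potential_branch_input_alias(branch, operands, pre_dispatch=pre_dispatch):
        raise UnsupportedAliasMutationException("branch might be aliasing the input!")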
def unique_graph_id(proxy_mode, prefix):
    """Returns a unique name and id for a graph to be added to a proxy_mode tracer"""
    # There are probably better ways - I know that create_arg has some self-incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{prefix}_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate
    return i, next_name
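
# Illustrative sketch (hypothetical helper, not part of the original module): how a
# higher-order op typically attaches a freshly traced subgraph to the root tracer
# under a collision-free name while proxy tracing.
def _example_register_subgraph(proxy_mode, subgraph, prefix="example_graph"):
    _, name = unique_graph_id(proxy_mode, prefix=prefix)
    proxy_mode.tracer.root.register_module(name, subgraph)
    return name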
def _from_fun(t):
    from torch._functorch.aot_autograd import from_fun
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
            )
        else:
            # clone of a functional tensor produces a functional tensor
            # but we want to avoid it so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t
def clone_outputs_aliasing_inputs(args):
    input_storage = {
        StorageWeakRef(arg._typed_storage())
        for arg in args
        if isinstance(arg, torch.Tensor)
    }

    def maybe_clone(t):
        if (
            isinstance(t, torch.Tensor)
            and StorageWeakRef(t._typed_storage()) in input_storage
        ):
            return t.clone()
        return t

    return maybe_clone
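
# Illustrative sketch (hypothetical wrapper, not part of the original module): wrapping
# a callable so that any output sharing storage with one of the inputs is cloned, which
# keeps the resulting graph free of input-aliasing outputs.
def _example_clone_aliasing_outputs(fn, *args):
    maybe_clone = clone_outputs_aliasing_inputs(args)
    return pytree.tree_map(maybe_clone, fn(*args))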
def prepare_fw_with_masks(fn):
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            isinstance(ret, torch.Tensor) and ret.requires_grad for ret in fw_out
        ]

    return fw_with_masks
# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note [HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. Will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail
    # to fetch the proxy for the inputs and to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
    # only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create _tensor_constant as output.

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # In order to keep map functional for backward graph,
        # we clone outputs that are aliasing inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph
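
# Illustrative sketch (hypothetical helper and names, not part of the original module):
# the typical call pattern from a higher-order op's autograd key. Operands are first
# unwrapped into "clean" example tensors with _from_fun (usually with functionalization
# and proxy tracing suspended, per the Note above), the forward is run once to obtain
# example outputs, and then the forward and joint (forward + backward) graphs are traced.
def _example_build_fw_bw_graphs(fn, *operands):
    # `fn` is assumed to return a tuple/list of tensors, as branch/body callables of
    # higher-order ops do.
    fw_inputs = pytree.tree_map(_from_fun, operands)
    fw_outputs = pytree.tree_map(_from_fun, fn(*fw_inputs))
    fw_graph, joint_graph = create_fw_bw_graph(fn, False, fw_inputs, fw_outputs)
    return fw_graph, joint_graph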
def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(x, torch.Tensor) for x in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(x.shape[0] == flat_xs[0].shape[0] for x in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[x.shape for x in flat_xs]}"
        )

    a = zip(*flat_xs)

    pytrees = []
    for leaves in a:
        pytrees.append(pytree.tree_unflatten(leaves, inspec))
    return pytrees
def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
    assert out_spec is not None

    b = zip(*flat_out)

    stacked_out = []
    for leaves in b:
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # Backward graph can return None output when forward inputs don't require grad.
            # When we eagerly execute backward graph, we need to call _stack_pytree on its output,
            # therefore we need to deal with None output.
            stacked_out.append(None)  # type: ignore[arg-type]
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)
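
# Illustrative sketch (hypothetical example, not part of the original module): the two
# helpers above are inverses of each other. _unstack_pytree slices every leaf along its
# leading dimension into a list of per-slice pytrees, and _stack_pytree stacks the
# slices back together, which is how e.g. an eager map over the leading dimension can
# run its body one slice at a time.
def _example_stack_unstack_roundtrip():
    xs = {"a": torch.arange(6).reshape(3, 2), "b": torch.ones(3, 4)}
    slices = _unstack_pytree(xs)  # list of 3 dicts with leaves of shape (2,) and (4,)
    restacked = _stack_pytree(slices)
    assert torch.equal(restacked["a"], xs["a"])
    assert torch.equal(restacked["b"], xs["b"])
    return restacked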