# pytorch/torch/_higher_order_ops/utils.py
# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Union
import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import OperatorBase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.shape_prop import TensorMetadata
from torch.multiprocessing.reductions import StorageWeakRef
@dataclass
class UnsupportedAliasMutationException(RuntimeError):
reason: str
def autograd_not_implemented_inner(
operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
"""If autograd is enabled and any of the arguments require grad this will either
raise an error or return a DelayedError depending on the value of delayed.
Args:
operator: The Operator to call with the *args and **kwargs with
op_name: The name of the Operator
delayed_error: If True, return a DelayedError instead of raising an error
args: The flattened operands to the Operator
kwargs: The keyword arguments to the Operator
Raises:
RuntimeError: If autograd is enabled and any of the arguments to the Operator
"""
with torch._C._AutoDispatchBelowAutograd():
result = operator(*args, **kwargs)
flat_operands = pytree.arg_tree_leaves(*args)
if torch.is_grad_enabled() and any(
f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
):
if delayed_error:
err_fn = torch._C._functions.DelayedError(
f"Autograd not implemented for {str(operator)}",
1,
)
def fake_requires_grad(tensor):
if torch.is_floating_point(tensor) or torch.is_complex(tensor):
tensor = tensor.detach()
tensor.requires_grad = True
return tensor
return pytree.tree_map_only(
torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
)
else:
raise RuntimeError(f"Autograd not implemented for {str(operator)}")
return result
def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
def inner(*args, **kwargs):
return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)
return inner
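# Usage sketch for autograd_not_implemented (hedged, not part of this module):
# higher-order ops typically install it as their Autograd kernel so that running
# them under grad either raises or returns a DelayedError. `my_hop` below is a
# hypothetical HigherOrderOperator.
#
#   my_hop.py_impl(torch._C.DispatchKey.Autograd)(
#       autograd_not_implemented(my_hop, deferred_error=True)
#   )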
def _maybe_run_with_interpreter(fn):
maybe_interpreted_fn = fn
if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
# Running graph with interpreter is needed for propagating the stack_trace
def graph_with_interpreter(*args):
with fx_traceback.preserve_node_meta():
return torch.fx.Interpreter(fn).run(*args)
maybe_interpreted_fn = graph_with_interpreter
return maybe_interpreted_fn
def reenter_make_fx(fn):
from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER
@functools.wraps(fn)
def wrapped(*args):
assert (
_CURRENT_MAKE_FX_TRACER is not None
), "Cannot reenter make_fx when we're not under a make_fx tracing session"
return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
_maybe_run_with_interpreter(fn), *args
)
return wrapped
def _maybe_reenter_make_fx(fn):
from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER
if _CURRENT_MAKE_FX_TRACER is not None:
return reenter_make_fx(fn)
else:
def _maybe_make_fx_with_fake_mode(fn):
@functools.wraps(fn)
def wrapped(*args):
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
if fake_mode is None:
                    # we create a fake_mode here to make sure we can
                    # trace the graph with data-dependent calls e.g. .item()
                    return make_fx(fn, tracing_mode="fake")(*args)
                # Tracing with real tensors if all inputs have been fakified
return make_fx(fn)(*args)
return wrapped
return _maybe_make_fx_with_fake_mode(fn)
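# Illustrative sketch for _maybe_reenter_make_fx (assumption: called outside any
# make_fx tracing session): the wrapper falls back to a fresh make_fx call and
# returns a GraphModule.
#
#   gm = _maybe_reenter_make_fx(lambda x: x.sin())(torch.randn(3))
#   # gm is a torch.fx.GraphModule capturing the sin call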
@contextmanager
def _set_compilation_env():
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
# once we are confident fx tracing works with dynamo.
torch.fx._symbolic_trace._is_fx_tracing_flag = False
yield
finally:
torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
def _detect_input_mutation(gm):
input_nodes = set()
for node in gm.graph.nodes:
if node.op == "placeholder":
input_nodes.add(node)
if node.op == "call_function":
target = node.target
if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
for arg in node.args:
if arg in input_nodes:
return True
for _, module in gm.named_children():
if isinstance(module, torch.fx.GraphModule):
if _detect_input_mutation(module):
return True
return False
def _detect_input_alias(gm):
input_storages = set()
for node in gm.graph.nodes:
# We need to check existence of "val" because we reuse the logic here
# for map operator, where num_mapped_args is a scalar
# and doesn't have a "val" meta.
if (
node.op == "placeholder"
and "val" in node.meta
and isinstance(node.meta["val"], torch.Tensor)
):
input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
if node.op == "output":
            def check_alias(out):
                # out can be a non-Node leaf (e.g. an integer), so we only check
                # fx.Nodes that carry a tensor "val" meta.
                if (
                    isinstance(out, torch.fx.Node)
                    and "val" in out.meta
                    and isinstance(out.meta["val"], torch.Tensor)
                ):
                    out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                    return out_storage in input_storages
                return False
if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
return True
for _, module in gm.named_children():
if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
return True
return False
def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
try:
gm = make_fx(gm, pre_dispatch=pre_dispatch)(*inputs)
except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is functionalized
        return True
return _detect_input_mutation(gm) or _detect_input_alias(gm)
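# Usage sketch for has_potential_input_alias_or_mutation (hypothetical branches,
# not part of this module): a branch that mutates or aliases its input is
# flagged; a pure branch is not.
#
#   def mutating_branch(x):
#       x.add_(1)
#       return x.clone()
#
#   def pure_branch(x):
#       return x + 1
#
#   inp = (torch.randn(3),)
#   assert has_potential_input_alias_or_mutation(mutating_branch, inp)
#   assert not has_potential_input_alias_or_mutation(pure_branch, inp)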
def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
"""
Dispatch-trace the branch with inputs and check if
producing graph has mutable op on the input. This is
bit restrictive as the branch must be traceable.
"""
try:
gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is functionalized
        return True
return _detect_input_mutation(gm)
def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
"""
Dispatch-trace the branch with inputs and check if
producing graph has output aliasing the branch input. This is
bit restrictive as the branch must be traceable.
"""
try:
gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is functionalized
        return True
return _detect_input_alias(gm)
def unique_graph_id(proxy_mode, prefix):
"""Returns a unique name and id for a graph to be added to a proxy_mode tracer"""
# There are probably better ways - I know that create_arg has some self incrementing name
# magic to it, but since we explicitly have to get the name for register_module,
# I was not sure how to do that. This kinda simulates it.
next_name = None
i = 0
while not next_name:
candidate = f"{prefix}_{i}"
if hasattr(proxy_mode.tracer.root, candidate):
i += 1
else:
next_name = candidate
return i, next_name
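# Sketch (illustrative prefix, hypothetical tracer state): if the tracer root
# already has an attribute "body_graph_0", the next free name is picked:
#
#   i, name = unique_graph_id(proxy_mode, prefix="body_graph")
#   # -> (1, "body_graph_1")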
def _from_fun(t):
from torch._functorch.aot_autograd import from_fun
from torch._subclasses.functional_tensor import FunctionalTensor
if isinstance(t, torch.Tensor):
if t.dtype != torch.bool:
return torch.empty_strided(
t.size(),
t.stride(),
dtype=t.dtype,
requires_grad=t.requires_grad,
device=t.device,
)
else:
            # A clone of a functional tensor produces another functional tensor,
            # which we want to avoid, so we clone a non-functional version instead.
maybe_unfunc_t = t
if isinstance(t, FunctionalTensor):
torch._sync(t)
maybe_unfunc_t = from_fun(t)
elif torch._is_functional_tensor(t):
# need to handle both types of functionalization here:
# these are the tensors that came from the user,
# which could be either FunctionalTensorWrapper or FunctionalTensor
torch._sync(t)
maybe_unfunc_t = torch._from_functional_tensor(t)
return maybe_unfunc_t.clone()
return t
def clone_outputs_aliasing_inputs(args):
input_storage = {
StorageWeakRef(arg._typed_storage())
for arg in args
if isinstance(arg, torch.Tensor)
}
def maybe_clone(t):
if (
isinstance(t, torch.Tensor)
and StorageWeakRef(t._typed_storage()) in input_storage
):
return t.clone()
return t
return maybe_clone
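# Usage sketch for clone_outputs_aliasing_inputs (standalone example, not module
# code): outputs that share storage with any input get cloned; fresh tensors
# pass through unchanged.
#
#   args = (torch.randn(3),)
#   maybe_clone = clone_outputs_aliasing_inputs(args)
#   aliased = args[0].view(3)      # shares storage with the input
#   fresh = torch.randn(3)         # independent storage
#   assert maybe_clone(aliased) is not aliased   # cloned
#   assert maybe_clone(fresh) is fresh           # untouched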
def prepare_fw_with_masks(fn):
def fw_with_masks(*args):
fw_out = fn(*args)
        return fw_out, [
            isinstance(ret, torch.Tensor) and ret.requires_grad for ret in fw_out
        ]
return fw_with_masks
# Replaces None gradients of Tensor operands with all-zero gradients.
# `None` gradients are problematic for CUDA graphs, so they are replaced
# with all-zero tensors of the same shape as the corresponding operand.
def unmask_none_gradients(grads, operands):
allowed_types = (torch.Tensor, int, torch.SymInt)
assert all(
isinstance(o, allowed_types) for o in operands
), f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}"
unmasked_grads = []
for g, o in zip(grads, operands):
if g is not None:
unmasked_grads.append(g)
else:
# In case the operand is an int or a torch.SymInt, return None
# This can happen for lifted_arguments. E.g., the shapes of a dynamic tensor are lifted and passed
# as additional arguments
unmasked_grads.append(
torch.zeros_like(o) if isinstance(o, torch.Tensor) else None
)
return unmasked_grads
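# Worked example for unmask_none_gradients (sketch): a None grad for a Tensor
# operand becomes zeros of the same shape; a None grad for an int/SymInt
# operand stays None.
#
#   t = torch.randn(2, 2)
#   unmask_none_gradients([None, None], [t, 5])
#   # -> [torch.zeros(2, 2), None]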
# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
from torch._functorch.aot_autograd import AOTConfig, create_joint
    # Note: [HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization, but more can be
    # added when required. We will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on inputs: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of dispatching.
    # However, it's the outer wrapper that the tracer creates proxies for. This causes the tracer to fail to
    # fetch the proxy for the inputs and thus fail to capture any operations on them.
#
# 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
# wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
# only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
# when creating the output node, it fails to associate the wrapped tensor with its proxy.
# Instead, it will create _tensor_constant as output.
dummy_aot_config = AOTConfig(
fw_compiler=None, # type: ignore[arg-type]
bw_compiler=None, # type: ignore[arg-type]
partition_fn=None, # type: ignore[arg-type]
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False,
)
example_grad = [_from_fun(out) for out in fw_outputs]
num_grads = len(example_grad)
fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)
def joint_fn(*joint_operands_grads):
if use_output_and_grad_bw:
grads = joint_operands_grads[0]
inputs = joint_operands_grads[1][-1:]
else:
grads = joint_operands_grads[:num_grads]
inputs = joint_operands_grads[num_grads:]
joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
_, grads = joint(
list(inputs),
[grad for grad in grads if grad is not None and grad.requires_grad],
)
# Unmask None gradients to all-zero gradients
unmasked_grads = unmask_none_gradients(grads, inputs)
        # In order to keep map functional for the backward graph,
        # we clone outputs that alias inputs
maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)
return pytree.tree_map(maybe_clone, unmasked_grads)
if use_output_and_grad_bw:
example_xs_out = list(fw_inputs) + list(fw_outputs)
joint_graph = _maybe_reenter_make_fx(joint_fn)(
(list(example_grad), list(example_xs_out))
)
else:
example_xs_out = list(fw_inputs)
joint_graph = _maybe_reenter_make_fx(joint_fn)(
*(list(example_grad) + list(example_xs_out))
)
return fw_graph, joint_graph
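# Usage sketch for create_fw_bw_graph (hedged): callers such as the autograd
# keys of map/while_loop typically unwrap functional tensors first, then build
# both graphs. `f` and `operands` below are hypothetical.
#
#   fw_inputs = [_from_fun(x) for x in operands]
#   fw_outputs = [_from_fun(out) for out in f(*fw_inputs)]
#   fw_graph, joint_graph = create_fw_bw_graph(f, False, fw_inputs, fw_outputs)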
def _unstack_pytree(xs):
flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(x, torch.Tensor) for x in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
    if not all(x.shape[0] == flat_xs[0].shape[0] for x in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[x.shape for x in flat_xs]}"
        )
    # zip(*flat_xs) iterates over the leading dimension, yielding one tuple of
    # leaves per slice; each tuple is unflattened back into the input spec.
    slices = zip(*flat_xs)
    pytrees = [pytree.tree_unflatten(leaves, inspec) for leaves in slices]
return pytrees
def _stack_pytree(pytrees):
flat_out = []
out_spec = None
for pt in pytrees:
flat_pt, out_spec = pytree.tree_flatten(pt)
flat_out.append(flat_pt)
assert out_spec is not None
b = zip(*flat_out)
stacked_out = []
for leaves in b:
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
stacked_out.append(torch.stack(leaves))
elif all(leaf is None for leaf in leaves):
            # The backward graph can return None outputs when forward inputs don't require grad.
            # When we eagerly execute the backward graph, we need to call _stack_pytree on its
            # output, so we need to handle None outputs here.
stacked_out.append(None) # type: ignore[arg-type]
else:
raise RuntimeError(f"Cannot stack {leaves}.")
return pytree.tree_unflatten(stacked_out, out_spec)
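# Round-trip sketch (standalone example, not module code): _unstack_pytree
# splits every leaf along dim 0 and _stack_pytree inverts it.
#
#   xs = {"a": torch.ones(2, 3), "b": torch.zeros(2)}
#   chunks = _unstack_pytree(xs)     # list of 2 dicts with leaves of shape (3,) and ()
#   restacked = _stack_pytree(chunks)
#   assert torch.equal(restacked["a"], xs["a"])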
# We cannot call save_for_backward for symints. This helper function
# can be used to save symints as direct attributes of ctx in autograd.Function.
#
# For example, if args = (x, y, s0, z, s1),
# save_tensors_and_symints_for_backward will partition the args into two lists
# and a bookkeeping list pos:
#   partitioned_args[0] = (x, y, z)
#   partitioned_args[1] = (s0, s1)
#   pos = (0, 0, 1, 0, 1)
# The pos list keeps track of which partition each arg was placed into,
# so the original order can be recovered in saved_tensors_and_symints.
#
# In saved_tensors_and_symints, we recover the original args by iterating over
# the pos list and taking the next item from partitioned_args[pos[i]].
# We use t_idx and s_idx to keep track of the next index to read from the two lists.
def save_tensors_and_symints_for_backward(ctx, args):
assert all(
isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
), args
partitioned_args: list[Any] = [[], []]
pos = []
for i, arg in enumerate(args):
idx = 0 if isinstance(arg, torch.Tensor) else 1
partitioned_args[idx].append(arg)
pos.append(idx)
assert not hasattr(ctx, "sym_int_args"), "ctx already has sym_int_args attribute."
assert not hasattr(ctx, "pos"), "ctx already has pos attribute."
ctx.save_for_backward(*partitioned_args[0])
ctx.sym_int_args = partitioned_args[1]
ctx.pos = pos
def saved_tensors_and_symints(ctx):
args = []
t_idx = 0
s_idx = 0
saved_tensors = ctx.saved_tensors
for p in ctx.pos:
if p == 0:
args.append(saved_tensors[t_idx])
t_idx += 1
else:
args.append(ctx.sym_int_args[s_idx])
s_idx += 1
assert t_idx + s_idx == len(ctx.pos)
return tuple(args)
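# Usage sketch inside a hypothetical autograd.Function (names are illustrative):
#
#   class MyFn(torch.autograd.Function):
#       @staticmethod
#       def forward(ctx, x, s0, y):  # x, y: Tensors; s0: int/SymInt
#           save_tensors_and_symints_for_backward(ctx, [x, s0, y])
#           return x * s0
#
#       @staticmethod
#       def backward(ctx, grad):
#           x, s0, y = saved_tensors_and_symints(ctx)
#           return grad * s0, None, None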
def get_dummy_aot_autograd_config():
from torch._functorch.aot_autograd import AOTConfig
return AOTConfig(
fw_compiler=None, # type: ignore[arg-type]
bw_compiler=None, # type: ignore[arg-type]
partition_fn=None, # type: ignore[arg-type]
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False,
)
# Returns a copy of the first slice of t along the given dimension
def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
return torch.select_copy(t, dim, 0)
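# Sketch: selecting index 0 along `dim` drops that dimension, e.g.
#
#   first_slice_copy(torch.arange(6).reshape(2, 3))         # tensor([0, 1, 2])
#   first_slice_copy(torch.arange(6).reshape(2, 3), dim=1)  # tensor([0, 3])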
# Reports the differences between the meta of two tensors as a list of strings
def diff_tensor_meta(
meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
) -> list[str]:
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
pair_diffs = []
for meta_name in TensorMetadata._fields:
if not check_grad and meta_name == "requires_grad":
continue
val1 = getattr(meta1, meta_name)
val2 = getattr(meta2, meta_name)
try:
if val1 != val2:
pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
        except GuardOnDataDependentSymNode:
pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
continue
return pair_diffs
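# Sketch: each differing TensorMetadata field is reported as a "'name: a vs b'"
# string; e.g. two metas differing only in dtype would yield
# ["'dtype: torch.float32 vs torch.float64'"].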
# Note [lifted arg types in hop]
# For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
# This has implications for the types of lifted args for different dispatch keys:
# 1. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd need to support torch.SymInt
# lifted args because it's on the path of torch.compile(dynamic=True).
# 2. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd, CompositeExplicitAutograd need
# to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner
# hops may receive int inputs from the shape of outer tensor inputs.
# However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]):
allowed_types = (torch.Tensor, int, torch.SymInt)
assert all(
        isinstance(arg, allowed_types) for arg in lifted_args
), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"