# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Union

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import OperatorBase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.shape_prop import TensorMetadata
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad, this will either
    raise an error or return a DelayedError, depending on the value of delayed_error.

    Args:
        operator: The Operator to call with *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled, any of the arguments to the Operator
            require grad, and delayed_error is False
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")

        return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner


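# Usage sketch (hypothetical HOP name; the py_impl registration below follows the usual
# HigherOrderOperator pattern): the returned callable can serve as the Autograd kernel
# of an op that has no backward, so differentiable inputs either raise eagerly or
# produce a DelayedError that fires if the output is ever backpropagated through.
#
#     from torch._ops import HigherOrderOperator
#
#     my_hop = HigherOrderOperator("my_hop")  # hypothetical
#     my_hop.py_impl(torch._C.DispatchKey.Autograd)(
#         autograd_not_implemented(my_hop, deferred_error=True)
#     )
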
def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running graph with interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn


def reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped


def _maybe_reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:

        def _maybe_make_fx_with_fake_mode(fn):
            @functools.wraps(fn)
            def wrapped(*args):
                from torch._guards import detect_fake_mode

                fake_mode = detect_fake_mode(args)
                if fake_mode is None:
                    # we create a fake_mode here to make sure we could
                    # trace the graph with data-dependent calls e.g. .item()
                    return make_fx(fn, tracing_mode="fake")(*args)
                # Trace in real mode since all inputs have already been fakified
                return make_fx(fn)(*args)

            return wrapped

        return _maybe_make_fx_with_fake_mode(fn)


@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing


def _detect_input_mutation(gm):
    input_nodes = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_nodes.add(node)
        if node.op == "call_function":
            target = node.target
            if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
                for arg in node.args:
                    if arg in input_nodes:
                        return True

    for _, module in gm.named_children():
        if isinstance(module, torch.fx.GraphModule):
            if _detect_input_mutation(module):
                return True

    return False


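# Minimal sketch of what the check above catches (toy graph, not from this file):
# tracing a function that mutates its argument in place leaves a call_function node
# whose target has a mutable schema (e.g. aten.add_.Tensor) applied to a placeholder.
#
#     gm = make_fx(lambda x: x.add_(1))(torch.randn(3))
#     assert _detect_input_mutation(gm)
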
def _detect_input_alias(gm):
    input_storages = set()
    for node in gm.graph.nodes:
        # We need to check existence of "val" because we reuse the logic here
        # for map operator, where num_mapped_args is a scalar
        # and doesn't have a "val" meta.
        if (
            node.op == "placeholder"
            and "val" in node.meta
            and isinstance(node.meta["val"], torch.Tensor)
        ):
            input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
        if node.op == "output":

            def check_alias(out):
                if (
                    # out can be an integer
                    isinstance(out, torch.fx.Node)
                    and out is not None
                    and "val" in out.meta
                    and isinstance(out.meta["val"], torch.Tensor)
                ):
                    out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                    return out_storage in input_storages
                return False

            if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                return True

    for _, module in gm.named_children():
        if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
            return True

    return False


def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
    try:
        gm = make_fx(gm, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    return _detect_input_mutation(gm) or _detect_input_alias(gm)


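# Usage sketch (toy callable, not from this file): the callable is re-traced with
# make_fx on the provided inputs and the resulting graph is scanned for input
# mutation or input aliasing.
#
#     xs = (torch.randn(3, 4),)
#     assert has_potential_input_alias_or_mutation(lambda x: x.mul_(2), xs)
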
def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if
    the resulting graph has a mutable op on the input. This is
    a bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    return _detect_input_mutation(gm)


def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if
    the resulting graph has an output aliasing the branch input. This is
    a bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    return _detect_input_alias(gm)


def unique_graph_id(proxy_mode, prefix):
    """Returns a unique name and id for a graph to be added to a proxy_mode tracer"""
    # There are probably better ways - I know that create_arg has some self incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{prefix}_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate
    return i, next_name


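# Illustration (hypothetical prefix): if the tracer root already owns attributes
# "cond_graph_0" and "cond_graph_1", then
#
#     unique_graph_id(proxy_mode, "cond_graph")  # -> (2, "cond_graph_2")
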
def _from_fun(t):
    from torch._functorch.aot_autograd import from_fun
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
                device=t.device,
            )
        else:
            # clone of a functional tensor produces a functional tensor
            # but we want to avoid it so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t


def clone_outputs_aliasing_inputs(args):
    input_storage = {
        StorageWeakRef(arg._typed_storage())
        for arg in args
        if isinstance(arg, torch.Tensor)
    }

    def maybe_clone(t):
        if (
            isinstance(t, torch.Tensor)
            and StorageWeakRef(t._typed_storage()) in input_storage
        ):
            return t.clone()
        return t

    return maybe_clone


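# Illustration (toy tensors, not from this file): outputs that share storage with any
# of the given args are cloned, everything else is passed through unchanged.
#
#     x = torch.randn(4)
#     maybe_clone = clone_outputs_aliasing_inputs((x,))
#     maybe_clone(x.view(2, 2))  # cloned, since it aliases x's storage
#     maybe_clone(x + 1)         # returned as-is
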
def prepare_fw_with_masks(fn):
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
            for ret in fw_out
        ]

    return fw_with_masks


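# Illustration (toy forward, not from this file): the wrapper returns the forward
# outputs together with a per-output mask marking which outputs require grad, which is
# the calling convention create_joint consumes in create_fw_bw_graph below.
#
#     def fw(x):
#         return (x.sin(), x.new_zeros(()))
#
#     x = torch.randn(2, requires_grad=True)
#     out, masks = prepare_fw_with_masks(fw)(x)  # masks == [True, False]
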
# This function replaces None gradients with all-zero gradients.
# `None` gradients are problematic for CUDA graphs, so they are replaced
# with all-zero tensors for better optimization.
def unmask_none_gradients(grads, operands):
    allowed_types = (torch.Tensor, int, torch.SymInt)
    assert all(
        isinstance(o, allowed_types) for o in operands
    ), f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}"

    unmasked_grads = []
    for g, o in zip(grads, operands):
        if g is not None:
            unmasked_grads.append(g)
        else:
            # In case the operand is an int or a torch.SymInt, keep None.
            # This can happen for lifted arguments, e.g., the shapes of a dynamic tensor
            # are lifted and passed as additional arguments.
            unmasked_grads.append(
                torch.zeros_like(o) if isinstance(o, torch.Tensor) else None
            )

    return unmasked_grads


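# Illustration (toy values, not from this file): tensor operands get zero gradients in
# place of None, while int/SymInt operands keep None.
#
#     t = torch.randn(2, 2)
#     unmask_none_gradients([None, None], [t, 3])  # -> [torch.zeros_like(t), None]
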
# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note: [HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. We will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    #    but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    #    However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail
    #    to fetch the proxy for the inputs and fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    #    wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
    #    only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
    #    when creating the output node, it fails to associate the wrapped tensor with its proxy.
    #    Instead, it will create _tensor_constant as output.

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # Unmask None gradients to all-zero gradients
        unmasked_grads = unmask_none_gradients(grads, inputs)

        # In order to keep map functional for backward graph,
        # we clone outputs that are aliasing inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, unmasked_grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph


def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(x, torch.Tensor) for x in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(x.shape[0] == flat_xs[0].shape[0] for x in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[x.shape for x in flat_xs]}"
        )

    a = zip(*flat_xs)

    pytrees = [pytree.tree_unflatten(leaves, inspec) for leaves in a]
    return pytrees


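# Illustration (toy pytree, not from this file): the leading dimension is unbound, so a
# pytree of stacked tensors becomes a list of per-slice pytrees with the same structure.
#
#     xs = {"a": torch.randn(2, 3), "b": torch.randn(2)}
#     _unstack_pytree(xs)
#     # -> [{"a": <shape (3,)>, "b": <scalar>}, {"a": <shape (3,)>, "b": <scalar>}]
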
def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
    assert out_spec is not None
    b = zip(*flat_out)
    stacked_out = []
    for leaves in b:
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # Backward graph can return None output when forward inputs don't require grad.
            # When we eagerly execute the backward graph, we need to call _stack_pytree on its output,
            # therefore we need to deal with None output.
            stacked_out.append(None)  # type: ignore[arg-type]
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)


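# Illustration (toy pytrees, not from this file): _stack_pytree is the inverse of
# _unstack_pytree; matching leaves are re-stacked along a new leading dimension and
# all-None leaves stay None.
#
#     pts = [{"a": torch.randn(3)}, {"a": torch.randn(3)}]
#     _stack_pytree(pts)  # -> {"a": tensor of shape (2, 3)}
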
# We cannot call save_for_backward for symints. This helper function
# can be used to save symints as direct attributes of ctx in autograd.Function.
#
# For example, if args = (x, y, s0, z, s1),
# save_tensors_and_symints_for_backward will partition the args into two lists, and a bookkeeping list pos:
# partitioned_args[0] = (x, y, z)
# partitioned_args[1] = (s0, s1)
# pos = (0, 0, 1, 0, 1)
# The pos list keeps track of which partition each arg goes into,
# so that it can be recovered in saved_tensors_and_symints.
#
# In saved_tensors_and_symints, we can recover the original args by
# iterating over the pos list and popping one item from the front of partitioned_args[pos[i]].
# We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists.
def save_tensors_and_symints_for_backward(ctx, args):
    assert all(
        isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
    ), args
    partitioned_args: list[Any] = [[], []]
    pos = []
    for arg in args:
        idx = 0 if isinstance(arg, torch.Tensor) else 1
        partitioned_args[idx].append(arg)
        pos.append(idx)

    assert not hasattr(ctx, "sym_int_args"), "ctx already has sym_int_args attribute."
    assert not hasattr(ctx, "pos"), "ctx already has pos attribute."
    ctx.save_for_backward(*partitioned_args[0])
    ctx.sym_int_args = partitioned_args[1]
    ctx.pos = pos


def saved_tensors_and_symints(ctx):
    args = []
    t_idx = 0
    s_idx = 0
    saved_tensors = ctx.saved_tensors
    for p in ctx.pos:
        if p == 0:
            args.append(saved_tensors[t_idx])
            t_idx += 1
        else:
            args.append(ctx.sym_int_args[s_idx])
            s_idx += 1
    assert t_idx + s_idx == len(ctx.pos)
    return tuple(args)


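# Usage sketch (hypothetical autograd.Function, not from this file): the two helpers
# are meant to be used as a pair, so SymInt/int args survive the round trip through
# ctx alongside the saved tensors.
#
#     class MyFn(torch.autograd.Function):  # hypothetical
#         @staticmethod
#         def forward(ctx, x, s):
#             save_tensors_and_symints_for_backward(ctx, [x, s])
#             return x * s
#
#         @staticmethod
#         def backward(ctx, grad_out):
#             x, s = saved_tensors_and_symints(ctx)
#             return grad_out * s, None
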
def get_dummy_aot_autograd_config():
    from torch._functorch.aot_autograd import AOTConfig

    return AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )


# Slices off the first element of a given dimension
def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
    return torch.select_copy(t, dim, 0)


# Reports the differences between the metadata of two tensors as a list of strings
def diff_tensor_meta(
    meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
) -> list[str]:
    from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode

    pair_diffs = []
    for meta_name in TensorMetadata._fields:
        if not check_grad and meta_name == "requires_grad":
            continue
        val1 = getattr(meta1, meta_name)
        val2 = getattr(meta2, meta_name)
        try:
            if val1 != val2:
                pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
        except GuardOnDataDependentSymNode as _:
            pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
            continue
    return pair_diffs


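# Illustration (toy metadata; _extract_tensor_metadata is a private shape_prop helper
# assumed here only for the example): mismatching fields are reported as
# "'<field>: <val1> vs <val2>'" strings.
#
#     from torch.fx.passes.shape_prop import _extract_tensor_metadata
#
#     diff_tensor_meta(
#         _extract_tensor_metadata(torch.randn(2)),
#         _extract_tensor_metadata(torch.randn(3)),
#     )  # expected: ["'shape: torch.Size([2]) vs torch.Size([3])'"]
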
# Note [lifted arg types in hop]
# For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
# This has implications for the types of lifted args for different dispatch keys:
# 1. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd need to support torch.SymInt
#    lifted args because it's on the path of torch.compile(dynamic=True).
# 2. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd, CompositeExplicitAutograd need
#    to support int arguments. In the eager run case, we re-trace the subgraph in the Autograd key, so inner
#    hops may receive int inputs from the shape of outer tensor inputs.
# However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]):
    allowed_types = (torch.Tensor, int, torch.SymInt)
    assert all(
        isinstance(arg, allowed_types) for arg in lifted_args
    ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"