# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple, Union

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import OperatorBase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.shape_prop import TensorMetadata
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad this will either
    raise an error or return a DelayedError depending on the value of delayed_error.

    Args:
        operator: The Operator to call with the *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled, any of the arguments require grad,
            and delayed_error is False
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner
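

# A minimal usage sketch (illustrative only): a HigherOrderOperator typically
# registers this wrapper as its Autograd-key implementation so that calling it
# with inputs that require grad raises (or delays) an "Autograd not implemented"
# error instead of producing silently wrong gradients. `my_hop` below is a
# hypothetical HigherOrderOperator, not part of this module.
#
#     my_hop.py_impl(torch._C.DispatchKey.Autograd)(
#         autograd_not_implemented(my_hop, deferred_error=True)
#     )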


def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running the graph with an interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn


def reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped


def _maybe_reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:

        def _maybe_make_fx_with_fake_mode(fn):
            @functools.wraps(fn)
            def wrapped(*args):
                from torch._guards import detect_fake_mode

                fake_mode = detect_fake_mode(args)
                if fake_mode is None:
                    # we create a fake_mode here to make sure we can
                    # trace the graph with data-dependent calls e.g. .item()
                    return make_fx(fn, tracing_mode="fake")(*args)
                # Trace in "real" mode if all inputs have already been fakified
                return make_fx(fn)(*args)

            return wrapped

        return _maybe_make_fx_with_fake_mode(fn)
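

# A minimal usage sketch (illustrative only): outside of an active make_fx trace,
# this falls back to a plain make_fx call, using a fake mode when the example
# inputs are not already fake tensors.
#
#     def f(x):
#         return x.sin() + 1
#
#     gm = _maybe_reenter_make_fx(f)(torch.randn(3))
#     # gm is a torch.fx.GraphModule containing sin and add nodes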


@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing


def _detect_input_mutation(gm):
    input_nodes = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_nodes.add(node)
        if node.op == "call_function":
            target = node.target
            if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
                for arg in node.args:
                    if arg in input_nodes:
                        return True

    for _, module in gm.named_children():
        if isinstance(module, torch.fx.GraphModule):
            if _detect_input_mutation(module):
                return True

    return False


def _detect_input_alias(gm):
    input_storages = set()
    for node in gm.graph.nodes:
        # We need to check existence of "val" because we reuse the logic here
        # for the map operator, where num_mapped_args is a scalar
        # and doesn't have a "val" meta.
        if (
            node.op == "placeholder"
            and "val" in node.meta
            and isinstance(node.meta["val"], torch.Tensor)
        ):
            input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
        if node.op == "output":

            def check_alias(out):
                if (
                    out is not None
                    and "val" in out.meta
                    and isinstance(out.meta["val"], torch.Tensor)
                ):
                    out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                    return out_storage in input_storages
                return False

            if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                return True

    for _, module in gm.named_children():
        if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
            return True

    return False


def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
    try:
        gm = make_fx(gm, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    return _detect_input_mutation(gm) or _detect_input_alias(gm)
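

# A minimal usage sketch (illustrative only; the lambdas below are hypothetical
# branches). A branch that mutates or returns a view of its input is expected to
# be flagged, while a pure branch is not:
#
#     xs = (torch.randn(2, 3),)
#     has_potential_input_alias_or_mutation(lambda x: x.add_(1), xs)   # True (mutation)
#     has_potential_input_alias_or_mutation(lambda x: x.view(-1), xs)  # True (aliasing)
#     has_potential_input_alias_or_mutation(lambda x: x.sin(), xs)     # False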


def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check whether the
    resulting graph has a mutable op on the input. This is a
    bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    return _detect_input_mutation(gm)


def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check whether the
    resulting graph has an output aliasing the branch input. This is a
    bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    return _detect_input_alias(gm)


def unique_graph_id(proxy_mode, prefix):
    """Returns a unique name and id for a graph to be added to a proxy_mode tracer"""
    # There are probably better ways - I know that create_arg has some self incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{prefix}_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate
    return i, next_name
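

# A minimal usage sketch (illustrative only; `proxy_mode` and `gm` come from the
# surrounding trace). If the tracer root already owns "cond_graph_0" and
# "cond_graph_1", this returns (2, "cond_graph_2"):
#
#     i, name = unique_graph_id(proxy_mode, prefix="cond_graph")
#     proxy_mode.tracer.root.register_module(name, gm)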


def _from_fun(t):
    from torch._functorch.aot_autograd import from_fun
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
                device=t.device,
            )
        else:
            # clone of a functional tensor produces a functional tensor
            # but we want to avoid that, so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t


def clone_outputs_aliasing_inputs(args):
    input_storage = {
        StorageWeakRef(arg._typed_storage())
        for arg in args
        if isinstance(arg, torch.Tensor)
    }

    def maybe_clone(t):
        if (
            isinstance(t, torch.Tensor)
            and StorageWeakRef(t._typed_storage()) in input_storage
        ):
            return t.clone()
        return t

    return maybe_clone
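

# A minimal usage sketch (illustrative only): keep a subgraph functional by
# cloning any output that shares storage with one of its inputs.
#
#     maybe_clone = clone_outputs_aliasing_inputs(args)
#     outputs = pytree.tree_map(maybe_clone, fn(*args))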


def prepare_fw_with_masks(fn):
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
            for ret in fw_out
        ]

    return fw_with_masks
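

# A minimal usage sketch (illustrative only): the wrapper returns the forward
# outputs together with a boolean mask marking which outputs require grad.
#
#     def fn(x, y):
#         return (x.sin(), y.detach())
#
#     out, masks = prepare_fw_with_masks(fn)(
#         torch.randn(2, requires_grad=True), torch.randn(2)
#     )
#     # masks == [True, False]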


# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note: [HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. We will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
    # fetch the proxy for the inputs and thus fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
    # only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create a _tensor_constant as output.

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # In order to keep map functional for the backward graph,
        # we clone outputs that are aliasing inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph
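

# A minimal usage sketch (illustrative only): building forward and joint graphs
# for a simple element-wise function. fw_inputs/fw_outputs are example tensors,
# typically produced via _from_fun during HOP tracing.
#
#     def fn(x):
#         return (x.sin(),)
#
#     fw_inputs = (torch.randn(3, requires_grad=True),)
#     fw_outputs = fn(*fw_inputs)
#     fw_graph, joint_graph = create_fw_bw_graph(fn, False, fw_inputs, fw_outputs)
#     # joint_graph takes (*example_grads, *fw_inputs) and returns input gradients.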


def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
        )

    a = zip(*flat_xs)

    pytrees = [pytree.tree_unflatten(tuple, inspec) for tuple in a]
    return pytrees


def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
    assert out_spec is not None
    b = zip(*flat_out)
    stacked_out = []
    for leaves in b:
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # The backward graph can return None outputs when a forward input doesn't require grad.
            # When we eagerly execute the backward graph, we need to call _stack_pytree on its output,
            # therefore we need to deal with None outputs.
            stacked_out.append(None)  # type: ignore[arg-type]
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)
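

# A minimal round-trip sketch (illustrative only): _unstack_pytree splits every
# leaf along its leading dimension and _stack_pytree reverses the operation.
#
#     xs = {"a": torch.randn(3, 2), "b": torch.randn(3, 4)}
#     slices = _unstack_pytree(xs)       # list of 3 dicts with shapes (2,) and (4,)
#     restacked = _stack_pytree(slices)  # same structure and shapes as xs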


# We cannot call save_for_backward for symints. This helper function
# can be used to save symints as direct attributes of ctx in autograd.Function.
#
# For example, if args = (x, y, s0, z, s1),
# save_tensors_and_symints_for_backward will partition the args into two lists, and a bookkeeping list pos:
# partitioned_args[0] = (x, y, z)
# partitioned_args[1] = (s0, s1)
# pos = (0, 0, 1, 0, 1)
# The pos list keeps track of which partition each arg
# was placed in, so that it can be recovered in saved_tensors_and_symints.
#
# In saved_tensors_and_symints, we can recover the original args by
# iterating over the pos list and popping one item from the front of partitioned_args[pos[i]].
# We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists.
def save_tensors_and_symints_for_backward(ctx, args):
    assert all(
        isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
    ), args
    partitioned_args: List[Any] = [[], []]
    pos = []
    for i, arg in enumerate(args):
        idx = 0 if isinstance(arg, torch.Tensor) else 1
        partitioned_args[idx].append(arg)
        pos.append(idx)

    assert not hasattr(ctx, "sym_int_args"), "ctx already has sym_int_args attribute."
    assert not hasattr(ctx, "pos"), "ctx already has pos attribute."
    ctx.save_for_backward(*partitioned_args[0])
    ctx.sym_int_args = partitioned_args[1]
    ctx.pos = pos


def saved_tensors_and_symints(ctx):
    args = []
    t_idx = 0
    s_idx = 0
    saved_tensors = ctx.saved_tensors
    for p in ctx.pos:
        if p == 0:
            args.append(saved_tensors[t_idx])
            t_idx += 1
        else:
            args.append(ctx.sym_int_args[s_idx])
            s_idx += 1
    assert t_idx + s_idx == len(ctx.pos)
    return tuple(args)
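

# A minimal usage sketch (illustrative only; `MyFn` is a hypothetical
# autograd.Function whose forward receives a mix of tensors and symints).
#
#     class MyFn(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x, s0, y):
#             save_tensors_and_symints_for_backward(ctx, (x, s0, y))
#             return x * s0 + y
#
#         @staticmethod
#         def backward(ctx, grad_out):
#             x, s0, y = saved_tensors_and_symints(ctx)
#             return grad_out * s0, None, grad_out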


def get_dummy_aot_autograd_config():
    from torch._functorch.aot_autograd import AOTConfig

    return AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )


# Slices off the first element of a given dimension
def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
    return torch.select_copy(t, dim, 0)


# Reports the difference between meta of two tensors in a string
def diff_tensor_meta(
    meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
) -> List[str]:
    from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode

    pair_diffs = []
    for meta_name in TensorMetadata._fields:
        if not check_grad and meta_name == "requires_grad":
            continue
        val1 = getattr(meta1, meta_name)
        val2 = getattr(meta2, meta_name)
        try:
            if val1 != val2:
                pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
        except GuardOnDataDependentSymNode as _:
            pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
            continue
    return pair_diffs
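

# A minimal usage sketch (illustrative only; _extract_tensor_metadata is a
# private helper of torch.fx.passes.shape_prop and may change):
#
#     from torch.fx.passes.shape_prop import _extract_tensor_metadata
#
#     meta1 = _extract_tensor_metadata(torch.randn(2, 3))
#     meta2 = _extract_tensor_metadata(torch.randn(2, 4, dtype=torch.float64))
#     diff_tensor_meta(meta1, meta2)
#     # -> differences reported for "shape" and "dtype"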


# Note [lifted arg types in hop]
# For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
# This has implications for the types of lifted args for different dispatch keys:
# 1. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd need to support torch.SymInt
# lifted args because they're on the path of torch.compile(dynamic=True).
# 2. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd, CompositeExplicitAutograd need
# to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner
# hops may receive int inputs from the shape of outer tensor inputs.
# However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
def validate_subgraph_args_types(lifted_args: Union[Tuple[Any, ...], List[Any]]):
    allowed_types = (torch.Tensor, int, torch.SymInt)
    assert all(
        isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args
    ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
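

# A minimal usage sketch (illustrative only): tensors, ints and SymInts pass the
# check, anything else trips the assertion.
#
#     validate_subgraph_args_types((torch.randn(2), 3))        # ok
#     validate_subgraph_args_types((torch.randn(2), "three"))  # AssertionError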