Files
pytorch/torch/_higher_order_ops/map.py
Yidi Wu 38fead8f7c [hop] preserve metadata in re-tracing hop subgraph by running with interpreter (#135159)
In this way, the interpreter.run can preserve the current metadata of subgraphs correctly when tracing the subgraphs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135159
Approved by: https://github.com/tugsbayasgalan
2024-09-05 21:36:56 +00:00

265 lines
8.6 KiB
Python

# mypy: allow-untyped-defs
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_maybe_run_with_interpreter,
reenter_make_fx,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from .utils import (
_from_fun,
_stack_pytree,
_unstack_pytree,
clone_outputs_aliasing_inputs,
prepare_fw_with_masks,
)
# TODO: We add this to prevent dynamo from tracing into map_wrapper,
# remove the wrapper call when it's ready.
class MapWrapper(HigherOrderOperator):
def __init__(self):
super().__init__("map")
def __call__(self, xs, *args):
return map_wrapper(xs, *args)
class MapImpl(HigherOrderOperator):
def __init__(self):
super().__init__("map_impl")
def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)
# Module-level operator instances.
# `map` is the user-facing entry point; NOTE that it shadows the Python
# builtin `map` inside this module.
map = MapWrapper()
# `map_impl` is the operator the dispatch-key implementations below are
# registered on.
map_impl = MapImpl()
# AOTConfig passed to `create_joint` when `create_fw_bw_graph` builds the
# joint graph. No compilers or partitioner are supplied (all None), hence the
# `type: ignore` markers.
dummy_aot_config = AOTConfig(
fw_compiler=None, # type: ignore[arg-type]
bw_compiler=None, # type: ignore[arg-type]
partition_fn=None, # type: ignore[arg-type]
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False,
)
def create_fw_bw_graph(f, num_mapped_args, *args):
    """Trace the forward graph and the joint graph for the map body ``f``.

    Args:
        f: Body callable, invoked as ``f(*example_xs, *example_pos_args)``.
        num_mapped_args: Number of leading entries of ``args`` that are
            mapped inputs; the remaining entries are positional args.
        *args: Mapped inputs followed by the positional args.

    Returns:
        ``(fw_graph, joint_graph)``, both produced by ``make_fx``. The joint
        graph is traced with inputs
        ``(*example_xs, *example_grad, *example_pos_args)``.

    Raises:
        RuntimeError: If any output of ``f`` on the example inputs is
            neither a ``torch.Tensor`` nor ``None``.
    """
    xs_part, rest_part = args[:num_mapped_args], args[num_mapped_args:]

    # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py
    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            # Example inputs: the first element yielded by _unstack_pytree
            # (presumably the slice at index 0 of the leading dimension --
            # confirm against utils.py).
            example_xs = _unstack_pytree(pytree.tree_map(_from_fun, xs_part))[0]
            example_pos_args = [
                _from_fun(a) if isinstance(a, torch.Tensor) else a
                for a in rest_part
            ]

            raw_out = f(*example_xs, *example_pos_args)
            example_flat_out = pytree.tree_map(_from_fun, raw_out)

            # Only tensors and None are accepted as outputs of the body.
            if any(
                out is not None and not isinstance(out, torch.Tensor)
                for out in example_flat_out
            ):
                raise RuntimeError(
                    "Expect outputs of map only contains tensors or None. "
                    f"Got types {[type(out) for out in example_flat_out]}."
                )

            # One example gradient per example output.
            example_grad = [_from_fun(out) for out in example_flat_out]

            fw_graph = make_fx(f)(*example_xs, *example_pos_args)

        # Number of leading joint inputs that are mapped over: the mapped
        # forward inputs followed by the output gradients.
        n_joint_mapped = len(example_grad) + len(example_xs)

        def joint_f(*example_args):
            """Joint body: takes (mapped inputs, grads, pos args), returns grads."""
            mapped_part = example_args[:n_joint_mapped]
            extra_args = example_args[n_joint_mapped:]

            fw_inputs = mapped_part[:num_mapped_args]
            out_grads = mapped_part[num_mapped_args:]

            joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config)
            primals = list(fw_inputs) + list(extra_args)
            tangents = [
                g for g in out_grads if g is not None and g.requires_grad
            ]
            _, grads = joint(primals, tangents)

            # In order to keep map functional for backward graph,
            # we clone outputs that are aliasing inputs
            maybe_clone = clone_outputs_aliasing_inputs(example_args)
            return pytree.tree_map(maybe_clone, grads)

        joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
        return fw_graph, joint_graph
def map_wrapper(f, xs, *args):
    """Validate and flatten the inputs of ``map``, then dispatch to ``map_impl``.

    ``xs`` is flattened with pytree; the body ``f`` is wrapped so that
    ``map_impl`` only ever sees flat inputs and flat outputs, and the result
    is unflattened back into the structure ``f`` returned.

    Args:
        f: Body callable, invoked as ``f(xs_tree, *args)`` where ``xs_tree``
            has the same pytree structure as ``xs``.
        xs: Pytree whose leaves must all be tensors sharing the same,
            non-zero leading dimension size.
        *args: Extra positional arguments passed through to ``f`` unchanged.

    Returns:
        The output of ``map_impl`` unflattened with the output spec recorded
        when ``f`` was run.

    Raises:
        RuntimeError: If a leaf of ``xs`` is not a tensor, if ``xs`` has no
            leaves, if a leaf is 0-dimensional, if the leading dimension is
            0, or if the leading dimensions disagree.
    """
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    if not all(isinstance(t, torch.Tensor) for t in flat_xs):
        raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")

    # Without at least one tensor there is no leading dimension to read;
    # fail with a clear message instead of an IndexError on shapes[0].
    if not flat_xs:
        raise RuntimeError("Mapped xs must contain at least one tensor.")

    num_mapped_args = len(flat_xs)
    shapes = [t.shape for t in flat_xs]

    # A 0-dim tensor has no leading dimension; shape[0] would raise IndexError.
    if any(len(cur_shape) == 0 for cur_shape in shapes):
        raise RuntimeError(
            f"Mapped xs must have at least one dimension to map over. Got shapes {shapes}."
        )

    leading_dim_size = shapes[0][0]
    if leading_dim_size == 0:
        raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")

    if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
        raise RuntimeError(
            f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
        )

    # Filled in by flat_fn when the body runs; needed to rebuild the output.
    out_spec = None

    def flat_fn(*flat_args):
        """Flat-in / flat-out adapter around ``f``."""
        nonlocal out_spec
        xs_tree = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec)
        unflattened_out = f(xs_tree, *flat_args[num_mapped_args:])
        flat_out, out_spec = pytree.tree_flatten(unflattened_out)
        return flat_out

    return pytree.tree_unflatten(
        map_impl(flat_fn, flat_xs, args), out_spec  # type: ignore[arg-type]
    )
class MapAutogradOp(torch.autograd.Function):
    """Autograd function that runs pre-traced forward and joint graphs through ``map_impl``."""

    @staticmethod
    def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
        """Run ``fw_graph`` through ``map_impl`` below the Autograd key.

        ``flat_args`` holds the mapped inputs (the first ``num_mapped_args``
        entries) followed by the positional args. All of them are saved for
        backward, together with the joint graph and the split point.
        Returns the outputs of ``map_impl`` as a tuple.
        """
        ctx.save_for_backward(*flat_args)
        ctx._joint_graph = joint_graph
        ctx._num_mapped_args = num_mapped_args

        mapped_inputs = flat_args[:num_mapped_args]
        extra_inputs = flat_args[num_mapped_args:]
        with torch._C._AutoDispatchBelowAutograd():
            outputs = map_impl(fw_graph, mapped_inputs, extra_inputs)
            return tuple(outputs)

    @staticmethod
    def backward(ctx, *flat_grads):
        """Run the saved joint graph through ``map_impl`` to get input grads.

        The joint graph is mapped over the saved mapped inputs concatenated
        with the incoming gradients. The three leading ``None`` results
        correspond to ``fw_graph``, ``joint_graph`` and ``num_mapped_args``.
        """
        split = ctx._num_mapped_args
        saved = ctx.saved_tensors
        grads = map_impl(
            ctx._joint_graph,
            saved[:split] + flat_grads,
            saved[split:],
        )
        return (None, None, None, *grads)
def trace_map(proxy_mode, func_overload, f, xs, pos_args):
    """Record a ``map_impl`` call in the graph being traced by ``proxy_mode``.

    The body ``f`` is re-traced into a graph module on example inputs,
    registered on the tracer root under a fresh ``body_graph_*`` name, and a
    ``call_function`` node targeting ``func_overload`` is created with
    ``(body_graph, xs, pos_args)`` as arguments.

    Returns:
        The result of ``track_tensor_tree`` applied to the example outputs,
        each tensor expanded with a new leading dimension of size
        ``xs[0].shape[0]``.
    """
    num_slices = xs[0].shape[0]
    example_input = _unstack_pytree(xs)[0]

    body_graph = reenter_make_fx(f)(*example_input, *pos_args)

    tracer = proxy_mode.tracer
    qualname = tracer.get_fresh_qualname("body_graph_")
    tracer.root.register_module(qualname, body_graph)

    # Run the body and build the example outputs with proxy tracing disabled
    # so none of this is recorded in the outer graph.
    with disable_proxy_modes_tracing():
        example_outs = body_graph(*example_input, *pos_args)

        def expand_tensor(t):
            # Non-tensor outputs are passed through untouched.
            if not isinstance(t, torch.Tensor):
                return t
            return t.expand(num_slices, *t.shape)

        expanded_outs = pytree.tree_map(expand_tensor, example_outs)

    node_args = (body_graph, list(xs), list(pos_args))
    proxy_args = pytree.tree_map(tracer.unwrap_proxy, node_args)
    out_proxy = tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="map_impl"
    )
    return track_tensor_tree(
        expanded_outs, out_proxy, constant=None, tracer=tracer
    )
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, xs, pos_args):
    """Eager implementation: call ``f`` once per unstacked input, stack results.

    ``f`` is invoked as ``f(*inp, *pos_args)`` for each element produced by
    ``_unstack_pytree(xs)``, in order; the collected outputs are combined
    with ``_stack_pytree``.
    """
    per_slice_outs = [f(*inp, *pos_args) for inp in _unstack_pytree(xs)]
    return _stack_pytree(per_slice_outs)
@map_impl.py_impl(DispatchKey.Autograd)
def map_autograd(f, xs, pos_args):
    """Autograd implementation: trace fw/joint graphs and run ``MapAutogradOp``.

    The mapped inputs and the positional args are concatenated into one flat
    argument list; ``len(xs)`` marks where the mapped inputs end.
    """
    num_mapped_args = len(xs)
    flat_inputs = (*xs, *pos_args)
    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *flat_inputs)
    return MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *flat_inputs)
@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(mode, f, xs, args):
    """Proxy-tracing implementation: hand off to ``trace_map`` with ``map_impl`` as the node target."""
    return trace_map(
        proxy_mode=mode,
        func_overload=map_impl,
        f=f,
        xs=xs,
        pos_args=args,
    )
@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(mode, f, xs, args):
    """Fake-tensor implementation: run ``map_dense`` with ``mode`` active."""
    with mode:
        result = map_dense(f, xs, args)
    return result
@map_impl.py_functionalize_impl
def map_functionalize(ctx, f, xs, pos_args):
    """Functionalization implementation of ``map_impl``.

    Unwraps the inputs, functionalizes the body, rejects bodies that may
    mutate or alias their inputs, then redispatches ``map_impl`` and re-wraps
    the result.

    Raises:
        UnsupportedAliasMutationException: If the body is detected as
            potentially mutating, or potentially aliasing, its inputs.
    """
    unwrapped_xs = ctx.unwrap_tensors(xs)
    unwrapped_args = ctx.unwrap_tensors(pos_args)
    wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f))

    # Checked in this order; the first violation found raises.
    violation_checks = (
        (_has_potential_branch_input_mutation, "torch.map is mutating the input!"),
        (_has_potential_branch_input_alias, "torch.map is aliasing the input!"),
    )

    with ctx.redispatch_to_next():
        # Build the example inputs without recording anything in a proxy trace.
        with disable_proxy_modes_tracing():
            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)

        # Not every ctx exposes a `mode`; treat a missing one as falsy.
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch

        # The checks run on the original body `f`, not the functionalized one.
        for has_violation, message in violation_checks:
            if has_violation(f, example_inputs, pre_dispatch=pre_dispatch):
                raise UnsupportedAliasMutationException(message)

        map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
        return ctx.wrap_tensors(map_return)