mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Support map autograd and pytree in/out. (#101633)
Rebased https://github.com/pytorch/pytorch/pull/100494 and added dummy AOTConfig. This PR adds autograd and pytree support for the map operator. Implementation-wise: 1. We temporarily make two HigherOrderOperators, "map" and "map_impl": - "map" is user-facing. Currently, it unwraps the pytrees in inputs and creates a flat_fn for it. Dynamo currently cannot deal with pytree.tree_flatten and pytree.tree_unflatten, so we make it a HigherOrderOperator to trigger dynamo's logic for handling HigherOrderOperators. - "map_impl" is the actual operator that works with the rest of torch subsystems such as functionalization, make_fx. It accepts flattened arguments, and a num_mapped_args integer denoting how many of the flattened arguments need to be mapped, i.e. their first dimension will be unstacked. 2. We create the forward and backward graph in the autograd key and call torch.autograd.Function. Currently, the backward graph is recomputation-based and we need to partition the joint graph in the future to be more efficient. 
Example traced graphs for map operators: ### Case 1: simple f and autograd ```python def f(x, y): return x + y def g(xs, y): out = control_flow.map(f, xs, y) return torch.autograd.grad(out, (xs, y), torch.ones_like(out)) gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True)) # gm.print_readable() produces following: class g(torch.nn.Module): def forward(self, xs_1: f32[3, s1, s2], y_1: f32[s2]): # No stacktrace found for following nodes body_graph_0 = self.body_graph_0 map_impl = torch.ops.map_impl(body_graph_0, 1, xs_1, y_1); body_graph_0 = None getitem: f32[3, s1, s2] = map_impl[0]; map_impl = None ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False) is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like); getitem = None body_graph_1 = self.body_graph_1 map_impl_1 = torch.ops.map_impl(body_graph_1, 2, xs_1, ones_like, y_1); body_graph_1 = xs_1 = ones_like = None getitem_1 = map_impl_1[0] getitem_2: f32[3, s1, s2] = map_impl_1[1] getitem_3: f32[3, s2] = map_impl_1[2]; map_impl_1 = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_3, [0], True); getitem_3 = None sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0); y_1 = None view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None return (getitem_2, view) class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s2]): # No stacktrace found for following nodes add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = None return [add] class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]): # No stacktrace found for following nodes add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg3_1); arg1_1 = None is_same_size = torch.ops.aten.is_same_size.default(add, arg2_1); add = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg2_1, [0], 
True) sym_size: Sym(s2) = torch.ops.aten.sym_size(arg3_1, 0); arg3_1 = None view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None return [None, arg2_1, view] ``` ### Case 2: list input/output f and autograd ```python def f(x, y): return [x[0].cos() + y.sin(), x[1].sin() * y.cos()] def g(xs, y): out = control_flow.map(f, xs, y) flat_out, _ = pytree.tree_flatten(out) flat_inp, _ = pytree.tree_flatten((xs, y)) requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad] return torch.autograd.grad(flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]) gm = make_fx(g, tracing_mode="symbolic")( [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)], torch.ones(5, requires_grad=True)) # gm.print_readable() produces following: class g(torch.nn.Module): def forward(self, xs, y): xs_1: f32[3, s1, s2], xs_2: f32[3, s1, s2], y_1: f32[s2], = fx_pytree.tree_flatten_spec([xs, y], self._in_spec) # No stacktrace found for following nodes body_graph_0 = self.body_graph_0 map_impl = torch.ops.map_impl(body_graph_0, 2, xs_1, xs_2, y_1); body_graph_0 = None getitem: f32[3, s1, s2] = map_impl[0] getitem_1: f32[3, s1, s2] = map_impl[1]; map_impl = None ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False) ones_like_1: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem_1, pin_memory = False) is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like); getitem = None is_same_size_1 = torch.ops.aten.is_same_size.default(getitem_1, ones_like_1); getitem_1 = None body_graph_1 = self.body_graph_1 map_impl_1 = torch.ops.map_impl(body_graph_1, 4, xs_1, xs_2, ones_like, ones_like_1, y_1); body_graph_1 = xs_1 = xs_2 = ones_like = ones_like_1 = None getitem_2 = map_impl_1[0] getitem_3 = map_impl_1[1] getitem_4: f32[3, s1, s2] = map_impl_1[2] getitem_5: f32[3, s2] = map_impl_1[3]; map_impl_1 = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_5, [0], True); getitem_5 
= None sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0); y_1 = None view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None return pytree.tree_unflatten([getitem_4, view], self._out_spec) class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]): # No stacktrace found for following nodes cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1); arg1_1 = None sin: f32[s2] = torch.ops.aten.sin.default(arg3_1) add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1); arg2_1 = None cos_1: f32[s2] = torch.ops.aten.cos.default(arg3_1); arg3_1 = None mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1); sin_1 = cos_1 = None return [add, mul] class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s1, s2], arg4_1: f32[s1, s2], arg5_1: f32[s2]): # No stacktrace found for following nodes cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1); arg1_1 = None sin: f32[s2] = torch.ops.aten.sin.default(arg5_1) add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1) cos_1: f32[s2] = torch.ops.aten.cos.default(arg5_1) mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1) is_same_size = torch.ops.aten.is_same_size.default(add, arg3_1); add = None is_same_size_1 = torch.ops.aten.is_same_size.default(mul, arg4_1); mul = None mul_1: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, sin_1); sin_1 = None mul_2: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, cos_1); arg4_1 = cos_1 = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(mul_1, [0], True); mul_1 = None sym_size: Sym(s2) = torch.ops.aten.sym_size(arg5_1, 0) view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = None # sin_2: f32[s2] = torch.ops.aten.sin.default(arg5_1) neg: f32[s2] = 
torch.ops.aten.neg.default(sin_2); sin_2 = None mul_3: f32[s2] = torch.ops.aten.mul.Tensor(view, neg); view = neg = None cos_2: f32[s1, s2] = torch.ops.aten.cos.default(arg2_1); arg2_1 = None mul_4: f32[s1, s2] = torch.ops.aten.mul.Tensor(mul_2, cos_2); mul_2 = cos_2 = None sum_2: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg3_1, [0], True); arg3_1 = None view_1: f32[s2] = torch.ops.aten.view.default(sum_2, [sym_size]); sum_2 = sym_size = None cos_3: f32[s2] = torch.ops.aten.cos.default(arg5_1); arg5_1 = None mul_5: f32[s2] = torch.ops.aten.mul.Tensor(view_1, cos_3); view_1 = cos_3 = None add_1: f32[s2] = torch.ops.aten.add.Tensor(mul_3, mul_5); mul_3 = mul_5 = None return [None, None, mul_4, add_1] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/101633 Approved by: https://github.com/zou3519
This commit is contained in:
@ -137,8 +137,10 @@ def inner(pred, true_fn, false_fn, operands):
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert (mode is not None), "Mode should always be enabled for python fallback key"
|
||||
with _pop_mode_temporarily() as mode:
|
||||
res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)
|
||||
return res
|
||||
if mode.enable_tracing:
|
||||
return trace_cond(mode, cond, pred, true_fn, false_fn, operands)
|
||||
else:
|
||||
return cond(pred, true_fn, false_fn, operands)
|
||||
|
||||
|
||||
@cond.py_impl(FakeTensorMode)
|
||||
@ -214,13 +216,19 @@ def _has_potential_branch_input_alias(branch, inputs):
|
||||
def _detect_input_alias(gm):
|
||||
input_storages = set()
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
# We need to check existence of "val" because we reuse the logic here
|
||||
# for map operator, where num_mapped_args is a scalar
|
||||
# and doesn't have a "val" meta.
|
||||
if node.op == "placeholder" and "val" in node.meta:
|
||||
input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
|
||||
if node.op == "output":
|
||||
for out in node.args:
|
||||
out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
|
||||
if out_storage in input_storages:
|
||||
return True
|
||||
def check_alias(out):
|
||||
if out is not None and "val" in out.meta:
|
||||
out_storage = StorageWeakRef(out.meta['val']._typed_storage())
|
||||
return out_storage in input_storages
|
||||
return False
|
||||
if any(pytree.tree_flatten(pytree.tree_map(check_alias, node.args))[0]):
|
||||
return True
|
||||
|
||||
for _, module in gm.named_children():
|
||||
if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
|
||||
|
@ -4,8 +4,10 @@ import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
|
||||
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
|
||||
from torch._functorch.aot_autograd import create_joint, AOTConfig
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
make_fx,
|
||||
@ -17,23 +19,162 @@ from torch.utils._python_dispatch import (
|
||||
_get_current_dispatch_mode,
|
||||
_pop_mode_temporarily,
|
||||
)
|
||||
from torch.utils._pytree import tree_flatten
|
||||
from torch._dispatch.python import suspend_functionalization
|
||||
from ._cond import _has_potential_branch_input_alias, _has_potential_branch_input_mutation, UnsupportedAliasMutationException
|
||||
|
||||
|
||||
map = HigherOrderOperator("map")
|
||||
# TODO: We add this to prevent dynamo from tracing into map_wrapper,
|
||||
# remove the wrapper call when it's ready.
|
||||
class MapWrapper(HigherOrderOperator):
|
||||
def __call__(self, xs, *args):
|
||||
return map_wrapper(xs, *args)
|
||||
|
||||
map = MapWrapper("map")
|
||||
map_impl = HigherOrderOperator("map_impl")
|
||||
|
||||
dummy_aot_config = AOTConfig(fw_compiler=None,
|
||||
bw_compiler=None,
|
||||
partition_fn=None,
|
||||
decompositions={},
|
||||
num_params_buffers=0,
|
||||
aot_id=0,
|
||||
keep_inference_input_mutations=False)
|
||||
|
||||
|
||||
def trace_map(proxy_mode, func_overload, f, xs, *args):
|
||||
if not isinstance(xs, torch.Tensor):
|
||||
raise ValueError("map() must loop over a tensor")
|
||||
if len(xs.shape) == 0 or xs.shape[0] == 0:
|
||||
raise ValueError("map() cannot be traced with scalar tensors or zero dimension tensors")
|
||||
if not all(isinstance(o, torch.Tensor) for o in args):
|
||||
raise ValueError("map() operands must be a list of tensors or modules")
|
||||
def create_fw_bw_graph(f, num_mapped_args, *args):
|
||||
mapped_xs = args[:num_mapped_args]
|
||||
pos_args = args[num_mapped_args:]
|
||||
|
||||
# Note: We create "clean" environments for make_fx by suspending all dispatch keys
|
||||
# between Autograd and Python key. Currently, we only suspend functionalization but more can be
|
||||
# added when required. Will encounter two problems if we don't suspend functionalization:
|
||||
#
|
||||
# 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
|
||||
# but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
|
||||
# However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
|
||||
# fetch the proxy for the inputs and fail to capture any operations on them.
|
||||
#
|
||||
# 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
|
||||
# wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
|
||||
# only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
|
||||
# when creating the output node, it fails to associate the wrapped tensor with its proxy.
|
||||
# Instead, it will create _tensor_constant as output.
|
||||
|
||||
with suspend_functionalization():
|
||||
with disable_proxy_modes_tracing():
|
||||
def from_fun(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
return torch.empty_strided(t.size(), t.stride(), requires_grad=t.requires_grad)
|
||||
return t
|
||||
|
||||
example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]]
|
||||
example_pos_args = [from_fun(arg) if isinstance(arg, torch.Tensor) else arg for arg in pos_args]
|
||||
example_flat_out = pytree.tree_map(from_fun, f(*example_xs, *example_pos_args))
|
||||
if any(not isinstance(out, torch.Tensor) for out in example_flat_out if out is not None):
|
||||
raise RuntimeError("Expect outputs of map only contains tensors or None. "
|
||||
f"Got types {[type(out) for out in example_flat_out]}.")
|
||||
example_grad = [from_fun(out) for out in example_flat_out]
|
||||
|
||||
|
||||
fw_graph = make_fx(f)(*example_xs, *example_pos_args)
|
||||
|
||||
def joint_f(*example_args):
|
||||
joint_mapped_args = example_args[:joint_num_mapped]
|
||||
args = example_args[joint_num_mapped:]
|
||||
|
||||
mapped_input = joint_mapped_args[:num_mapped_args]
|
||||
mapped_grads = joint_mapped_args[num_mapped_args:]
|
||||
|
||||
def fw_with_masks(*args):
|
||||
fw_out = f(*args)
|
||||
return fw_out, [True if isinstance(ret, torch.Tensor) and ret.requires_grad else False for ret in fw_out]
|
||||
|
||||
joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
|
||||
_, grads = joint(list(mapped_input) + list(args),
|
||||
[grad for grad in mapped_grads if grad is not None and grad.requires_grad])
|
||||
|
||||
# In order to keep map functional for backward graph,
|
||||
# we clone outputs that are aliasing inputs
|
||||
input_storage = {StorageWeakRef(arg._typed_storage()) for arg in example_args if isinstance(arg, torch.Tensor)}
|
||||
|
||||
def maybe_clone(t):
|
||||
if isinstance(t, torch.Tensor) and StorageWeakRef(t._typed_storage()) in input_storage:
|
||||
return t.clone()
|
||||
return t
|
||||
return pytree.tree_map(maybe_clone, grads)
|
||||
|
||||
joint_num_mapped = len(example_grad) + len(example_xs)
|
||||
joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
|
||||
return fw_graph, joint_graph
|
||||
|
||||
|
||||
def map_wrapper(f, xs, *args):
    """Flatten the pytree ``xs``, validate its leaves, and dispatch to ``map_impl``.

    Args:
        f: Callable invoked as ``f(xs_tree, *rest)``, where ``xs_tree`` is rebuilt
            (with the pytree structure of ``xs``) from the first ``len(flat_xs)``
            arguments that ``map_impl`` passes to the flattened function, and
            ``rest`` are the remaining arguments.
        xs: Pytree whose leaves must all be tensors with at least one dimension
            and the same, non-zero, leading dimension size.
        *args: Extra operands forwarded to ``map_impl`` after the flattened leaves
            of ``xs``.

    Returns:
        The flat result of ``map_impl`` unflattened with the pytree spec recorded
        from the most recent call of ``f``.

    Raises:
        RuntimeError: If ``xs`` contains a non-tensor leaf, contains no tensors at
            all, contains a 0-dim tensor, has a leading dimension of size 0, or
            has leaves whose leading dimensions disagree.
    """
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    if not all(isinstance(t, torch.Tensor) for t in flat_xs):
        raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")

    num_mapped_args = len(flat_xs)
    # Without this guard an empty pytree would fail below with a bare IndexError.
    if num_mapped_args == 0:
        raise RuntimeError("Mapped xs must contain at least one tensor.")

    shapes = [t.shape for t in flat_xs]
    # A 0-dim tensor has no leading dimension to map over.
    if any(len(cur_shape) == 0 for cur_shape in shapes):
        raise RuntimeError(
            f"Mapped xs must have at least one dimension. Got shapes {shapes}.")

    leading_dim_size = shapes[0][0]
    if leading_dim_size == 0:
        raise RuntimeError(
            "Leading dimensions of mapped xs cannot be 0.")

    if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
        raise RuntimeError(
            f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}.")

    # Filled in by flat_fn; the output structure is only known once f has run.
    out_spec = None

    def flat_fn(*flat_args):
        nonlocal out_spec
        xs_tree = pytree.tree_unflatten(flat_args[:num_mapped_args], xs_spec)
        unflattened_out = f(xs_tree, *flat_args[num_mapped_args:])
        flat_out, out_spec = pytree.tree_flatten(unflattened_out)
        return flat_out

    return pytree.tree_unflatten(map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec)
|
||||
|
||||
class MapAutogradOp(torch.autograd.Function):
    """Autograd function that runs ``map_impl`` on pre-built forward/joint graphs."""

    @staticmethod
    def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
        """Run ``map_impl`` with ``fw_graph`` and stash what ``backward`` needs.

        Args:
            ctx: Autograd context; receives the saved tensors, the joint graph
                and ``num_mapped_args``.
            fw_graph: Callable passed to ``map_impl`` for the forward computation.
            joint_graph: Callable passed to ``map_impl`` in ``backward``.
            num_mapped_args: How many leading entries of ``flat_args`` are mapped.
            *flat_args: Flat tensor operands (mapped ones first).

        Returns:
            The outputs of ``map_impl`` as a tuple.
        """
        ctx.save_for_backward(*flat_args)
        ctx._joint_graph = joint_graph
        ctx._num_mapped_args = num_mapped_args
        # Create the guard before entering ``try`` so that a failure while
        # constructing it is not masked by an UnboundLocalError from ``del guard``.
        guard = torch._C._AutoDispatchBelowAutograd()
        try:
            return (*map_impl(fw_graph, num_mapped_args, *flat_args), )
        finally:
            del guard

    @staticmethod
    def backward(ctx, *flat_grads):
        """Compute input gradients by mapping the joint graph.

        The joint graph is called through ``map_impl`` with the saved mapped
        inputs followed by ``flat_grads`` as the mapped operands, and the saved
        remaining inputs as the trailing operands.

        Returns:
            ``None`` for each of ``fw_graph``, ``joint_graph`` and
            ``num_mapped_args``, followed by the outputs of ``map_impl``.
        """
        fw_args = ctx.saved_tensors
        num_mapped = ctx._num_mapped_args
        fw_mapped_args = fw_args[:num_mapped]
        pos_args = fw_args[num_mapped:]

        # The incoming grads are mapped alongside the original mapped inputs.
        grads = map_impl(ctx._joint_graph, num_mapped + len(flat_grads), *fw_mapped_args, *flat_grads, *pos_args)
        return None, None, None, *grads
|
||||
|
||||
def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
|
||||
xs = list(args[:num_mapped])
|
||||
pos_args = list(args[num_mapped:])
|
||||
leading_dim_size = xs[0].shape[0]
|
||||
|
||||
example_input = _unstack_pytree(xs)[0]
|
||||
body_graph = f
|
||||
if not isinstance(body_graph, torch.fx.GraphModule):
|
||||
body_graph = make_fx(body_graph)(*example_input, *pos_args)
|
||||
|
||||
with disable_proxy_modes_tracing():
|
||||
body_graph = make_fx(f)(xs[0], *args)
|
||||
example_outs = body_graph(*example_input, *pos_args)
|
||||
|
||||
def expand_tensor(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.expand(leading_dim_size, *t.shape)
|
||||
return t
|
||||
expanded_outs = pytree.tree_map(expand_tensor, example_outs)
|
||||
|
||||
next_name = None
|
||||
i = 0
|
||||
@ -45,113 +186,145 @@ def trace_map(proxy_mode, func_overload, f, xs, *args):
|
||||
next_name = candidate
|
||||
|
||||
proxy_mode.tracer.root.register_module(next_name, body_graph)
|
||||
node_args = (body_graph, xs, *args)
|
||||
node_args = (body_graph, num_mapped, *args)
|
||||
proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), node_args)
|
||||
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
|
||||
name="map")
|
||||
outs = [body_graph(x, *args) for x in xs]
|
||||
# Implementation notes: we need to use new_empty() + copy_() here instead of stack() directly
|
||||
# because stack([...]) takes a fixed size list which will specialize dynamic shape here.
|
||||
# Meanwhile we want to preserve the looped over dimension as symbolic shape, such that:
|
||||
# ys: Tensor[s0, ...] = map(xs: Tensor[s0, ...], *args)
|
||||
out = outs[0].new_empty([xs.shape[0], *outs[0].shape])
|
||||
out.copy_(torch.stack(outs))
|
||||
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
|
||||
name="map_impl")
|
||||
return track_tensor_tree(expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer)
|
||||
|
||||
def _unstack_pytree(xs):
|
||||
flat_xs, inspec = pytree.tree_flatten(xs)
|
||||
if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
|
||||
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
|
||||
|
||||
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
|
||||
raise RuntimeError(f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}")
|
||||
|
||||
a = zip(*flat_xs)
|
||||
pytrees = []
|
||||
for tuple in a:
|
||||
pytrees.append(pytree.tree_unflatten(tuple, inspec))
|
||||
return pytrees
|
||||
|
||||
def _stack_pytree(pytrees):
|
||||
flat_out = []
|
||||
out_spec = None
|
||||
for pt in pytrees:
|
||||
flat_pt, out_spec = pytree.tree_flatten(pt)
|
||||
flat_out.append(flat_pt)
|
||||
b = zip(*flat_out)
|
||||
stacked_out = []
|
||||
for leaves in b:
|
||||
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
|
||||
stacked_out.append(torch.stack(leaves))
|
||||
elif all(leaf is None for leaf in leaves):
|
||||
# Backward graph can return None output when forward inputs doesn't require grad.
|
||||
# When we eagerly execute backward graph, we need to call _stack_pytree on its output,
|
||||
# therefore we need to deal with None output.
|
||||
stacked_out.append(None)
|
||||
else:
|
||||
raise RuntimeError(f"Cannot stack {leaves}.")
|
||||
return pytree.tree_unflatten(stacked_out, out_spec)
|
||||
|
||||
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, num_mapped_args, *args):
    """Eagerly apply ``f`` to each unstacked slice of the mapped operands.

    The first ``num_mapped_args`` entries of ``args`` are unstacked with
    ``_unstack_pytree``; ``f`` is called once per slice with the slice's
    elements followed by the remaining operands, and the per-slice results
    are combined with ``_stack_pytree``.
    """
    mapped_operands = args[:num_mapped_args]
    extra_operands = args[num_mapped_args:]
    per_slice_results = [
        f(*slice_inputs, *extra_operands)
        for slice_inputs in _unstack_pytree(mapped_operands)
    ]
    return _stack_pytree(per_slice_results)
|
||||
|
||||
|
||||
@map.py_impl(DispatchKey.CompositeExplicitAutograd)
|
||||
def map_cpu(f, xs, *args):
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert (mode is None), "Mode should never be enabled for CPU/CUDA key"
|
||||
return torch.stack([f(x, *args) for x in xs])
|
||||
@map_impl.py_impl(DispatchKey.Autograd)
|
||||
def map_autograd(f, num_mapped_args, *args):
|
||||
fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
|
||||
flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
|
||||
return flat_out
|
||||
|
||||
|
||||
@map.py_impl(DispatchKey.Autograd)
|
||||
def map_autograd(f, xs, *args):
|
||||
# TODO: support autograd
|
||||
flat_operands, _ = tree_flatten([f, xs, args])
|
||||
assert all(not f.requires_grad for f in flat_operands
|
||||
if isinstance(f, torch.Tensor))
|
||||
|
||||
_ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
|
||||
return map(f, xs, *args)
|
||||
|
||||
|
||||
@map.py_impl(ProxyTorchDispatchMode)
|
||||
def map_proxy_torch_dispatch_mode(f, xs, *args):
|
||||
@map_impl.py_impl(ProxyTorchDispatchMode)
|
||||
def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert (mode is not None), "Mode should always be enabled for python fallback key"
|
||||
with _pop_mode_temporarily() as mode:
|
||||
res = trace_map(mode, map, f, xs, *args)
|
||||
return res
|
||||
if mode.enable_tracing:
|
||||
return trace_map(mode, map_impl, f, num_mapped, *args)
|
||||
else:
|
||||
return map_impl(f, num_mapped, *args)
|
||||
|
||||
|
||||
@map.py_impl(FakeTensorMode)
|
||||
def map_fake_tensor_mode(f, xs, *args):
|
||||
outs = [f(x, *args) for x in xs]
|
||||
return outs[0].new_empty([xs.shape[0], *outs[0].shape])
|
||||
@map_impl.py_impl(FakeTensorMode)
|
||||
def map_fake_tensor_mode(f, num_mapped, *args):
|
||||
return map_dense(f, num_mapped, *args)
|
||||
|
||||
|
||||
@map.py_impl(DispatchKey.Functionalize)
|
||||
def map_func(f, xs, *args):
|
||||
@map_impl.py_impl(DispatchKey.Functionalize)
|
||||
def map_func(f, num_mapped, *args):
|
||||
reapply_views = torch._C._functionalization_reapply_views_tls()
|
||||
xs = args[:num_mapped]
|
||||
pos_args = args[num_mapped:]
|
||||
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
|
||||
unwrapped_args = _unwrap_all_tensors_from_functional(args, reapply_views=reapply_views)
|
||||
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views)
|
||||
mode = 'mutations_and_views' if reapply_views else 'mutations'
|
||||
|
||||
guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
|
||||
try:
|
||||
functional_map_fn = functionalize(f, remove=mode)
|
||||
inputs = (unwrapped_xs,) + unwrapped_args
|
||||
with disable_proxy_modes_tracing():
|
||||
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
|
||||
|
||||
if _has_potential_branch_input_mutation(f, inputs):
|
||||
if _has_potential_branch_input_mutation(f, example_inputs):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"torch.map is mutating the input!"
|
||||
)
|
||||
|
||||
if _has_potential_branch_input_alias(f, inputs):
|
||||
if _has_potential_branch_input_alias(f, example_inputs):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"torch.map is aliasing the input!"
|
||||
)
|
||||
|
||||
map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
|
||||
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
|
||||
return _wrap_all_tensors_to_functional(map_return, level=0)
|
||||
finally:
|
||||
del guard
|
||||
|
||||
@map.py_impl(torch._C._functorch.TransformType.Functionalize)
|
||||
def map_functionalize(interpreter, f, xs, *args):
|
||||
@map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
|
||||
def map_functionalize(interpreter, f, num_mapped, *args):
|
||||
"""
|
||||
Functionalization implementation for torch.map. Currently:
|
||||
1. We don't allow any input mutation inside the map function
|
||||
2. Our check for above condition is not exhaustive
|
||||
"""
|
||||
xs = args[:num_mapped]
|
||||
pos_args = args[num_mapped:]
|
||||
reapply_views = interpreter.functionalize_add_back_views()
|
||||
mode = 'mutations_and_views' if reapply_views else 'mutations'
|
||||
# At this point, we will see functionalized tensors, so need to unwrap them first
|
||||
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
|
||||
unwrapped_args = _unwrap_all_tensors_from_functional(args, reapply_views=reapply_views)
|
||||
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views)
|
||||
|
||||
functional_map_fn = functionalize(f, remove=mode)
|
||||
|
||||
with interpreter.lower():
|
||||
inputs = (unwrapped_xs,) + unwrapped_args
|
||||
if _has_potential_branch_input_mutation(functional_map_fn, inputs):
|
||||
with disable_proxy_modes_tracing():
|
||||
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
|
||||
if _has_potential_branch_input_mutation(f, example_inputs):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"torch.map is mutating the input!"
|
||||
)
|
||||
|
||||
if _has_potential_branch_input_alias(functional_map_fn, inputs):
|
||||
if _has_potential_branch_input_alias(f, example_inputs):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"torch.map is aliasing the input!"
|
||||
)
|
||||
|
||||
map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
|
||||
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
|
||||
return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())
|
||||
|
||||
# TODO(voz) Make this automatic for keys, this is very ugly atm
|
||||
map.fallthrough(DispatchKey.PythonDispatcher)
|
||||
map.fallthrough(DispatchKey.PythonTLSSnapshot)
|
||||
map.fallthrough(DispatchKey.ADInplaceOrView)
|
||||
map.fallthrough(DispatchKey.BackendSelect)
|
||||
map.fallthrough(DispatchKey.AutocastCPU)
|
||||
map_impl.fallthrough(DispatchKey.PythonDispatcher)
|
||||
map_impl.fallthrough(DispatchKey.PythonTLSSnapshot)
|
||||
map_impl.fallthrough(DispatchKey.ADInplaceOrView)
|
||||
map_impl.fallthrough(DispatchKey.BackendSelect)
|
||||
map_impl.fallthrough(DispatchKey.AutocastCPU)
|
||||
|
@ -27,6 +27,7 @@ from torch.nn.utils.rnn import PackedSequence
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests, toleranceOverride, tol
|
||||
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
|
||||
from torch.testing._internal.common_modules import module_db, modules
|
||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
||||
from functorch import (
|
||||
grad, vjp, vmap, jacrev,
|
||||
make_fx
|
||||
@ -2975,12 +2976,12 @@ def _test_aot_autograd_module_helper(self, device, dtype, training, module_info,
|
||||
|
||||
|
||||
class TestEagerFusionOpInfo(AOTTestCase):
|
||||
@ops(op_db, allowed_dtypes=(torch.float,))
|
||||
@ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', aot_autograd_failures)
|
||||
def test_aot_autograd_exhaustive(self, device, dtype, op):
|
||||
_test_aot_autograd_helper(self, device, dtype, op)
|
||||
|
||||
@ops(op_db, allowed_dtypes=(torch.float,))
|
||||
@ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
|
||||
@patch("functorch.compile.config.debug_assert", True)
|
||||
@skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive',
|
||||
aot_autograd_failures | symbolic_aot_autograd_failures)
|
||||
|
@ -12,6 +12,15 @@ from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch._dynamo.exc import CondOpArgsMismatchError
|
||||
|
||||
def _fake_map(f, x, *args):
    """Reference implementation of map: call ``f`` on every unstacked slice of ``x``.

    ``x`` is split with ``_unstack_pytree``, ``f(slice, *args)`` is evaluated for
    each slice in order, and the results are merged with ``_stack_pytree``.
    """
    from functorch.experimental._map import _stack_pytree, _unstack_pytree
    results = [f(piece, *args) for piece in _unstack_pytree(x)]
    return _stack_pytree(results)
|
||||
|
||||
|
||||
class TestControlFlow(TestCase):
|
||||
def test_cond_no_trace(self):
|
||||
def true_fn(x):
|
||||
@ -34,7 +43,7 @@ class TestControlFlow(TestCase):
|
||||
|
||||
x = torch.randn(4, device="cuda")
|
||||
pred = torch.tensor(False, device="cuda")
|
||||
result = cond(False, true_fn, false_fn, [x])
|
||||
result = cond(pred, true_fn, false_fn, [x])
|
||||
self.assertEqual(result, torch.cos(x))
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
||||
@ -45,8 +54,138 @@ class TestControlFlow(TestCase):
|
||||
xs = torch.ones(3, 2, 2, device="cuda")
|
||||
y = torch.ones(2, device="cuda")
|
||||
res = control_flow.map(f, xs, y)
|
||||
expected = _fake_map(f, xs, y)
|
||||
self.assertEqual(expected, res)
|
||||
|
||||
self.assertEqual(res, control_flow.map(f, torch.ones(3, 2, 2), torch.ones(2)))
|
||||
def test_map_illegal_inputs(self):
|
||||
def f(x, y):
|
||||
return x[0] + x[1] + y
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r"Mapped xs can only consist of tensors\. Got xs \[3, tensor\(\[1\., 1\.\]\)\]\."):
|
||||
_ = control_flow.map(f, (3, torch.ones(2)), torch.ones(2))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r"Leading dimensions of mapped xs cannot be 0\."):
|
||||
_ = control_flow.map(f, (torch.ones(0, 1, 2), torch.ones(0, 1, 2)), torch.ones(2))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r"Leading dimensions of mapped xs must be consistent\. "
|
||||
r"Got shapes \[torch\.Size\(\[3, 4, 5\]\), torch\.Size\(\[4, 4, 5\]\)\]\."):
|
||||
_ = control_flow.map(f, (torch.ones(3, 4, 5), torch.ones(4, 4, 5)), torch.ones(5))
|
||||
|
||||
def test_map_illegal_outputs(self):
|
||||
def f(x, y):
|
||||
return x.item()
|
||||
|
||||
def f1(x, y):
|
||||
return y.size()
|
||||
|
||||
def f2(x, y):
|
||||
return None
|
||||
|
||||
x = torch.ones([3])
|
||||
y = torch.ones([1, 2, 3])
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expect outputs of map only contains tensors or None\."):
|
||||
_ = control_flow.map(f, x, y)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expect outputs of map only contains tensors or None\."):
|
||||
out = control_flow.map(f1, x, y)
|
||||
|
||||
# return None is OK
|
||||
_ = control_flow.map(f2, x, y)
|
||||
|
||||
|
||||
def test_map_list_in_out(self):
|
||||
def f(x, y):
|
||||
return [[x[0][0] + y]]
|
||||
|
||||
xs = [[torch.ones(3, 2, 2)]]
|
||||
y = torch.ones(2)
|
||||
res = control_flow.map(f, xs, y)
|
||||
expected = _fake_map(f, xs, y)
|
||||
self.assertEqual(len(res), 1)
|
||||
self.assertEqual(len(res[0]), 1)
|
||||
self.assertEqual(expected, res)
|
||||
|
||||
def test_map_dict_in_out(self):
    """Map a body that takes a nested dict and returns a dict; compare with ``_fake_map``."""
    def body(x, y):
        return {"c": x["a"]["b"] + y}

    nested_xs = {"a": {"b": torch.ones(3, 2, 2)}}
    bias = torch.ones(2)

    actual = control_flow.map(body, nested_xs, bias)
    reference = _fake_map(body, nested_xs, bias)

    # The output must be a single-entry mapping keyed by "c".
    self.assertEqual(len(actual), 1)
    self.assertIn("c", actual)
    self.assertEqual(reference, actual)
|
||||
|
||||
def test_map_autograd_simple(self):
    """Compare forward results and gradients of ``control_flow.map`` with ``_fake_map``.

    Both the mapped input and the extra argument require grad.
    """
    def body(x, y):
        return x.sin().cos() * y.cos().sin()

    xs = torch.ones(3, 2, 2, requires_grad=True)
    y = torch.ones(2, requires_grad=True)

    actual = control_flow.map(body, xs, y)
    reference = _fake_map(body, xs, y)

    # The same upstream gradient is fed to both backward computations.
    upstream = torch.ones_like(actual)
    actual_grads = torch.autograd.grad(actual, (xs, y), upstream)
    reference_grads = torch.autograd.grad(reference, (xs, y), upstream)

    self.assertEqual(reference, actual)
    self.assertEqual(reference_grads, actual_grads)
|
||||
|
||||
def test_map_autograd_simple_partial_grad(self):
    """Same comparison as the simple autograd test, but only ``xs`` requires grad."""
    def body(x, y):
        return x.sin().cos() * y.cos().sin()

    xs = torch.ones(3, 2, 2, requires_grad=True)
    # Disable the gradient computation for y
    y = torch.ones(2, requires_grad=False)

    actual = control_flow.map(body, xs, y)
    reference = _fake_map(body, xs, y)

    # Gradients are requested with respect to xs only.
    upstream = torch.ones_like(actual)
    actual_grads = torch.autograd.grad(actual, (xs,), upstream)
    reference_grads = torch.autograd.grad(reference, (xs,), upstream)

    self.assertEqual(reference, actual)
    self.assertEqual(reference_grads, actual_grads)
|
||||
|
||||
def test_map_autograd_no_grad_output(self):
    """Differentiate only the first of two map outputs.

    The body returns a pair; the second element is computed from ``y``
    alone, and ``y`` is created with ``requires_grad=False``. Gradients are
    taken for the first output with respect to ``xs[0]`` and compared with
    ``_fake_map``.
    """
    def body(x, y):
        return x[0].sin().cos() + y, y.cos().sin()

    xs = [torch.ones(3, 2, 2, requires_grad=True), torch.ones(3, 3)]
    # Disable the gradient computation for y
    y = torch.ones(2, requires_grad=False)

    actual = control_flow.map(body, xs, y)
    reference = _fake_map(body, xs, y)

    upstream = torch.ones_like(actual[0])
    actual_grads = torch.autograd.grad(actual[0], (xs[0],), upstream)
    reference_grads = torch.autograd.grad(reference[0], (xs[0],), upstream)

    self.assertEqual(reference, actual)
    self.assertEqual(reference_grads, actual_grads)
|
||||
|
||||
|
||||
def test_map_autograd_nested_list(self):
    """Forward and backward through ``map`` with nested-list inputs and outputs.

    The result of ``control_flow.map`` (outputs plus gradients with respect
    to every flattened input leaf) is compared against ``_fake_map``.
    """
    import torch.utils._pytree as pytree

    def body(x, y):
        pair, b = x
        c, d = pair
        return [[b.sin() * c.cos()], d.sin() * y.cos()]

    def forward_backward(map_op, fn, x, y):
        # Run the given map implementation, then differentiate all flattened
        # outputs with respect to all flattened leaves of x.
        out = map_op(fn, x, y)
        x_leaves, _ = pytree.tree_flatten(x)
        out_leaves, _ = pytree.tree_flatten(out)
        upstream = [torch.ones_like(leaf) for leaf in out_leaves]
        leaf_grads = torch.autograd.grad(out_leaves, x_leaves, upstream)
        return out, leaf_grads

    x = [
        [torch.randn(3, 2, 2, requires_grad=True), torch.randn(3, 2, 1, requires_grad=True)],
        torch.ones(3, 1, 2, requires_grad=True),
    ]
    y = torch.ones(1, requires_grad=True)

    real_result = forward_backward(control_flow.map, body, x, y)
    fake_result = forward_backward(_fake_map, body, x, y)
    self.assertEqual(real_result, fake_result)
|
||||
|
||||
|
||||
class TestControlFlowTraced(TestCase):
|
||||
@ -702,17 +841,15 @@ class TestControlFlowTraced(TestCase):
|
||||
):
|
||||
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
||||
|
||||
def check_map_graph(self, gm, key):
|
||||
def check_map_count(self, gm, op_count):
|
||||
i = 0
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "call_function" and node.target == torch.ops.map:
|
||||
i += 1
|
||||
self.assertEqual(
|
||||
node.meta[key].shape[0], node.args[1].meta[key].shape[0]
|
||||
)
|
||||
self.assertEqual(i, 1)
|
||||
for m in gm.modules():
|
||||
for node in m.graph.nodes:
|
||||
if node.op == "call_function" and node.target == torch.ops.map_impl:
|
||||
i += 1
|
||||
self.assertEqual(i, op_count)
|
||||
|
||||
def test_map_real(self):
|
||||
def test_tracing_map_real(self):
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
@ -724,9 +861,9 @@ class TestControlFlowTraced(TestCase):
|
||||
y = torch.randn(2)
|
||||
res = gm(x, y)
|
||||
self.assertEqual(res, g(x, y))
|
||||
self.check_map_graph(gm, "tensor_meta")
|
||||
self.check_map_count(gm, 1)
|
||||
|
||||
def test_map_symbolic(self):
|
||||
def test_tracing_map_symbolic_simple(self):
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
@ -738,7 +875,140 @@ class TestControlFlowTraced(TestCase):
|
||||
y = torch.randn(2)
|
||||
res = gm(x, y)
|
||||
self.assertEqual(res, g(x, y))
|
||||
self.check_map_graph(gm, "val")
|
||||
self.check_map_count(gm, 1)
|
||||
|
||||
def test_tracing_map_symbolic_list(self):
    """Symbolically trace a ``map`` over nested-list inputs and replay it on new shapes.

    The traced module must match eager execution and contain exactly one
    ``map_impl`` call according to ``check_map_count``.
    """
    def body(x, y):
        return [x[0][0] + y, x[1] * y]

    def run(xs, y, z):
        mapped = control_flow.map(body, xs, y)
        return mapped[0] + z, mapped[1] * z

    trace_xs = [[torch.ones(3, 4, 5)], torch.ones(3, 4, 5)]
    tracer = make_fx(run, tracing_mode="symbolic")
    gm = tracer(trace_xs, torch.ones(5), torch.ones(5))

    # Run the traced graph on inputs whose shapes differ from the trace-time ones.
    xs = [[torch.randn(4, 5, 6)], torch.ones(4, 5, 6)]
    y = torch.randn(6)
    z = torch.ones(6)
    traced_out = gm(xs, y, z)

    self.assertEqual(traced_out, run(xs, y, z))
    self.check_map_count(gm, 1)
|
||||
|
||||
def test_tracing_map_symbolic_dict(self):
    """Symbolically trace a ``map`` over nested-dict inputs and replay it on new shapes.

    The traced module must match eager execution and contain exactly one
    ``map_impl`` call according to ``check_map_count``.
    """
    def body(x, y):
        return {"d": x["b"]["a"] + y, "e": x["c"] * y}

    def run(xs, y, z):
        mapped = control_flow.map(body, xs, y)
        return {"f": mapped["d"] + z, "g": mapped["e"] * z}

    trace_xs = {"b": {"a": torch.ones(3, 4, 5)}, "c": torch.ones(3, 4, 5)}
    tracer = make_fx(run, tracing_mode="symbolic")
    gm = tracer(trace_xs, torch.ones(5), torch.ones(5))

    # Run the traced graph on inputs whose shapes differ from the trace-time ones.
    xs = {"b": {"a": torch.randn(4, 5, 6)}, "c": torch.ones(4, 5, 6)}
    y = torch.randn(6)
    z = torch.ones(6)
    traced_out = gm(xs, y, z)

    self.assertEqual(traced_out, run(xs, y, z))
    self.check_map_count(gm, 1)
|
||||
|
||||
def test_tracing_map_autograd_symbolic_simple(self):
    """Symbolically trace ``map`` followed by ``torch.autograd.grad``.

    The traced gradients must match eager execution on differently shaped
    inputs, and ``check_map_count`` must find two ``map_impl`` calls.
    """
    def body(x, y):
        return x + y

    def grads_of_map(xs, y):
        mapped = control_flow.map(body, xs, y)
        return torch.autograd.grad(mapped, (xs, y), torch.ones_like(mapped))

    tracer = make_fx(grads_of_map, tracing_mode="symbolic")
    gm = tracer(
        torch.ones(3, 4, 5, requires_grad=True),
        torch.ones(5, requires_grad=True),
    )

    xs = torch.randn(4, 5, 6, requires_grad=True)
    y = torch.randn(6, requires_grad=True)
    traced_out = gm(xs, y)

    self.assertEqual(traced_out, grads_of_map(xs, y))
    self.check_map_count(gm, 2)
|
||||
|
||||
|
||||
def test_tracing_map_autograd_symbolic_list(self):
    """Symbolically trace ``map`` plus autograd over list inputs.

    Only the input leaves with ``requires_grad`` set are differentiated.
    The traced gradients must match eager execution, and
    ``check_map_count`` must find two ``map_impl`` calls.
    """
    import torch.utils._pytree as pytree

    def body(x, y):
        return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]

    def grads_of_map(xs, y):
        mapped = control_flow.map(body, xs, y)
        out_leaves, _ = pytree.tree_flatten(mapped)
        in_leaves, _ = pytree.tree_flatten((xs, y))
        # Leaves created without requires_grad are excluded from differentiation.
        diff_inputs = [leaf for leaf in in_leaves if leaf.requires_grad]
        upstream = [torch.ones_like(leaf) for leaf in out_leaves]
        return torch.autograd.grad(out_leaves, diff_inputs, upstream)

    tracer = make_fx(grads_of_map, tracing_mode="symbolic")
    gm = tracer(
        [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
        torch.ones(5, requires_grad=True),
    )

    xs = [torch.randn(4, 5, 6), torch.ones(4, 5, 6, requires_grad=True)]
    y = torch.randn(6, requires_grad=True)
    traced_out = gm(xs, y)

    self.assertEqual(traced_out, grads_of_map(xs, y))
    self.check_map_count(gm, 2)
|
||||
|
||||
def test_tracing_map_autograd_symbolic_dict(self):
    """Symbolically trace ``map`` plus autograd over dict inputs.

    The body reads its mapped input as a dict and returns a list. The
    traced gradients must match eager execution, and ``check_map_count``
    must find two ``map_impl`` calls.
    """
    # NOTE(review): ``pytree`` is not imported inside this test, unlike the
    # list variant; presumably it is available at module level - confirm.
    def body(x, y):
        return [x["a"] + y, x["b"] * y]

    def grads_of_map(xs, y):
        mapped = control_flow.map(body, xs, y)
        out_leaves, _ = pytree.tree_flatten(mapped)
        in_leaves, _ = pytree.tree_flatten((xs, y))
        diff_inputs = [leaf for leaf in in_leaves if leaf.requires_grad]
        upstream = [torch.ones_like(leaf) for leaf in out_leaves]
        return torch.autograd.grad(out_leaves, diff_inputs, upstream)

    traced_x = {
        "a": torch.ones(3, 4, 5, requires_grad=True),
        "b": torch.ones(3, 4, 5, requires_grad=True),
    }
    tracer = make_fx(grads_of_map, tracing_mode="symbolic")
    gm = tracer(traced_x, torch.ones(5, requires_grad=True))

    xs = {
        "a": torch.randn(4, 5, 6, requires_grad=True),
        "b": torch.ones(4, 5, 6, requires_grad=True),
    }
    y = torch.randn(6, requires_grad=True)
    traced_out = gm(xs, y)

    self.assertEqual(traced_out, grads_of_map(xs, y))
    self.check_map_count(gm, 2)
|
||||
|
||||
def test_tracing_map_autograd_aot_functionalized(self):
    """Trace ``map`` plus autograd with and without functionalization.

    The mapped body performs an in-place ``add_``. The plain symbolic trace
    must match eager execution and contain two mutable ops; the trace taken
    under functionalization must contain none.
    """
    def inner(x, y):
        z = x - 1
        z.add_(1)
        return z * y

    def f(xs, y):
        res = control_flow.map(inner, xs, y)
        return torch.autograd.grad(res, (xs, y), torch.ones_like(res))

    def f_wrapper(func):
        # Run ``func`` with functionalization enabled and convert its outputs
        # with ``from_fun``; functionalization is always disabled afterwards.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            torch._enable_functionalization(reapply_views=False)
            try:
                outputs = func(*args, **kwargs)
                return pytree.tree_map(from_fun, outputs)
            finally:
                torch._disable_functionalization()
        return wrapper

    example_inputs = (
        torch.ones(3, 2, 4, requires_grad=True),
        torch.ones(2, 4, requires_grad=True),
    )
    gm = make_fx(f, tracing_mode="symbolic")(*example_inputs)
    fgm = make_fx(f_wrapper(f), tracing_mode="symbolic")(*example_inputs)

    xs = torch.ones(3, 4, 5, requires_grad=True)
    y = torch.ones(4, 5, requires_grad=True)
    self.assertEqual(gm(xs, y), f(xs, y))

    def count_mutable(gm):
        # Count call_function nodes whose schema is mutable, descending into
        # the body graph referenced by the first argument of each map_impl node.
        total = 0
        for node in gm.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target == torch.ops.map_impl:
                total += count_mutable(getattr(gm, str(node.args[0])))
                continue
            schema = getattr(node.target, "_schema", None)
            if schema:
                total += int(schema.is_mutable)
        return total

    self.assertEqual(count_mutable(fgm), 0)
    # One for forward, one for recomputation logic in backward
    self.assertEqual(count_mutable(gm), 2)
|
||||
|
||||
def test_map_functionalized(self):
|
||||
def map_fn(x, y):
|
||||
@ -762,6 +1032,7 @@ class TestControlFlowTraced(TestCase):
|
||||
for node in gm.body_graph_0.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
self.assertTrue(not node.target._schema.is_mutable)
|
||||
self.check_map_count(gm, 1)
|
||||
|
||||
def test_map_functionalized_aot_func(self):
|
||||
def map_fn(x, y):
|
||||
@ -848,11 +1119,11 @@ class TestControlFlowTraced(TestCase):
|
||||
torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
|
||||
)
|
||||
pred = torch.tensor(False)
|
||||
x = torch.randn(3, 2, 2)
|
||||
y = torch.randn(2)
|
||||
x = torch.randn(3, 2, 4)
|
||||
y = torch.randn(4)
|
||||
res = gm(pred, x, y)
|
||||
self.assertEqual(res, g(pred, x, y))
|
||||
self.check_map_graph(gm, "tensor_meta")
|
||||
self.check_map_count(gm, 1)
|
||||
|
||||
def test_nested_map_cond_symbolic(self):
|
||||
def true_fn(x, y):
|
||||
@ -875,7 +1146,7 @@ class TestControlFlowTraced(TestCase):
|
||||
y = torch.randn(2)
|
||||
res = gm(pred, x, y)
|
||||
self.assertEqual(res, g(pred, x, y))
|
||||
self.check_map_graph(gm, "val")
|
||||
self.check_map_count(gm, 1)
|
||||
|
||||
def test_nested_cond_map_cond_symbolic(self):
|
||||
|
||||
@ -909,6 +1180,7 @@ class TestControlFlowTraced(TestCase):
|
||||
y = torch.randn(2)
|
||||
res = gm(p, pred, xs, y)
|
||||
self.assertEqual(res, main(p, pred, xs, y))
|
||||
self.check_map_count(gm, 2)
|
||||
|
||||
def test_cond_with_sym_pred(self):
|
||||
def true_fn(x):
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
|
||||
from torch.testing._internal.common_utils import TestGradients, run_tests
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, OpDTypes)
|
||||
|
||||
@ -17,7 +18,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
|
||||
class TestBwdGradients(TestGradients):
|
||||
# Tests that gradients are computed correctly
|
||||
@_gradcheck_ops(op_db)
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db)
|
||||
def test_fn_grad(self, device, dtype, op):
|
||||
# This is verified by test_dtypes in test_ops.py
|
||||
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
|
||||
@ -51,7 +52,7 @@ class TestBwdGradients(TestGradients):
|
||||
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
||||
|
||||
# Test that gradients of gradients are computed correctly
|
||||
@_gradcheck_ops(op_db)
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db)
|
||||
def test_fn_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.supports_gradgrad:
|
||||
|
@ -16,6 +16,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
constrain_range, guard_int, GuardOnDataDependentSymNode
|
||||
)
|
||||
from torch.testing._internal.custom_op_db import custom_op_db
|
||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
||||
from torch.testing._internal.common_device_type import ops
|
||||
from torch._C import _disabled_torch_function_impl
|
||||
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
|
||||
@ -1610,17 +1611,17 @@ def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False):
|
||||
self.assertEqual(new_out, old_out)
|
||||
|
||||
class TestProxyTensorOpInfo(TestCase):
|
||||
@ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
|
||||
@ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
|
||||
def test_make_fx_exhaustive(self, device, dtype, op):
|
||||
_test_make_fx_helper(self, device, dtype, op, "real")
|
||||
|
||||
@ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
|
||||
@ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures))
|
||||
def test_make_fx_fake_exhaustive(self, device, dtype, op):
|
||||
_test_make_fx_helper(self, device, dtype, op, "fake")
|
||||
|
||||
@ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
|
||||
@ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
|
||||
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures)
|
||||
def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
|
||||
|
75
torch/testing/_internal/control_flow_opinfo_db.py
Normal file
75
torch/testing/_internal/control_flow_opinfo_db.py
Normal file
@ -0,0 +1,75 @@
|
||||
import torch
|
||||
import functools
|
||||
from torch.testing import make_tensor
|
||||
from functorch.experimental.control_flow import map
|
||||
from torch.testing._internal.opinfo.core import (
|
||||
OpInfo,
|
||||
SampleInput,
|
||||
)
|
||||
from torch.testing._internal.common_dtype import all_types_and
|
||||
|
||||
def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield one ``SampleInput`` for the map-based OpInfo entries.

    The sample's input is a list of two ``(2, 2, 2)`` tensors and its
    ``args`` are two one-element tensors. All four are created through
    ``make_tensor`` with ``low=0.1`` and ``high=2`` on the requested
    ``device``/``dtype`` with the requested ``requires_grad`` flag.
    ``opinfo`` and ``kwargs`` are accepted but not used.
    """
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    # Creation order (both list entries, then both extras) is kept so that
    # random generation is consumed in the same sequence.
    mapped_inputs = [make_arg(2, 2, 2, low=0.1, high=2) for _ in range(2)]
    extra_args = tuple(make_arg(1, low=0.1, high=2) for _ in range(2))
    yield SampleInput(mapped_inputs, args=extra_args)
|
||||
|
||||
def inner_f(x, y0, y1):
    """Return a two-element list computed from ``x[0]``, ``x[1]``, ``y0`` and ``y1``.

    The first element is ``(cos(x[0]) + 1) * y0`` and the second is
    ``cos(x[1] + sin(y1))`` viewed with the size of ``x[1]``. The additions
    of 1 and the outer cosine are applied in place (``add_`` / ``cos_``) on
    freshly computed intermediates, not on the arguments themselves.
    """
    first, second = x[0], x[1]
    scaled = first.cos().add_(1.) * y0
    shifted = (second + y1.sin()).cos_().view(second.size())
    return [scaled, shifted]
|
||||
|
||||
def simple_map(xs, y0, y1):
    """Apply ``inner_f`` through a single level of ``map`` over ``xs``.

    ``map`` here is the name imported by this module, not the builtin.
    """
    def body(elem, a, b):
        return inner_f(elem, a, b)

    return map(body, xs, y0, y1)
|
||||
|
||||
def nested_map(xs, y0, y1):
    """Apply ``inner_f`` through two nested levels of ``map`` over ``xs``.

    ``map`` here is the name imported by this module, not the builtin.
    """
    def outer_body(sub_xs, a, b):
        def inner_body(elem, c, d):
            return inner_f(elem, c, d)

        return map(inner_body, sub_xs, a, b)

    return map(outer_body, xs, y0, y1)
|
||||
|
||||
def triple_nested_map(xs, y0, y1):
    """Apply ``inner_f`` through three nested levels of ``map`` over ``xs``.

    ``map`` here is the name imported by this module, not the builtin.
    """
    def level0(xs0, a0, b0):
        def level1(xs1, a1, b1):
            def level2(elem, a2, b2):
                return inner_f(elem, a2, b2)

            return map(level2, xs1, a1, b1)

        return map(level1, xs0, a0, b0)

    return map(level0, xs, y0, y1)
|
||||
|
||||
# OpInfo entries for the map-based callables defined above. All entries share
# the same sample-input function, dtypes and check flags; only the name and
# the wrapped callable differ, and they are created in the order listed.
control_flow_opinfo_db = [
    OpInfo(
        op_name,
        op=op_fn,
        sample_inputs_func=sample_inputs_map,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
    )
    for op_name, op_fn in (
        ("MapControlflowOp", simple_map),
        ("NestedMapControlflowOp", nested_map),
        ("TripleNestedMapControlflowOp", triple_nested_map),
    )
]
|
Reference in New Issue
Block a user