Support map autograd and pytree in/out (#100494)

This PR adds autograd and pytree support for the map operator.

Implementation-wise:

1. We temporarily introduce two HigherOrderOperators, "map" and "map_impl":
- "map" is user-facing. It unwraps the pytree inputs and creates a flat_fn over the flattened arguments. Because Dynamo currently cannot deal with pytree.tree_flatten and pytree.tree_unflatten, we make "map" a HigherOrderOperator so that Dynamo's logic for handling HigherOrderOperators is triggered (a sketch of this flattening flow is shown below).
- "map_impl" is the actual operator that works with the rest of the torch subsystems such as functionalization and make_fx. It accepts flattened arguments plus a num_mapped_args integer denoting how many of the flattened arguments are mapped over, i.e. have their first dimension unstacked.

2. We create the forward and backward graphs under the Autograd key and invoke them through a torch.autograd.Function. Currently, the backward graph is recomputation-based; in the future we need to partition the joint graph to make this more efficient. The idea is sketched below.
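
The sketch below illustrates the recomputation idea with a plain torch.autograd.Function (a hypothetical helper, not the MapAutogradOp added in this PR): only the inputs are saved, and backward re-runs the body under enable_grad and differentiates it, which mirrors how the traced joint graph recomputes the forward instead of saving intermediate activations. It assumes the body returns a single tensor and that every tensor input requires grad.

```python
import torch

class RecomputeFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fn, *inputs):
        # Save only the inputs; intermediate activations are not kept.
        ctx._fn = fn
        ctx.save_for_backward(*inputs)
        with torch.no_grad():
            return fn(*inputs)

    @staticmethod
    def backward(ctx, grad_out):
        # Recompute the forward from the saved inputs, then differentiate it.
        inputs = [x.detach().requires_grad_(True) for x in ctx.saved_tensors]
        with torch.enable_grad():
            out = ctx._fn(*inputs)
            grads = torch.autograd.grad(out, inputs, grad_out)
        return (None, *grads)

x = torch.randn(3, requires_grad=True)
y = RecomputeFn.apply(lambda t: t.sin().cos(), x)
y.sum().backward()  # gradients flow through the recomputed forward
```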

Example traced graphs for the map operator:
### Case 1: simple f and autograd
```python
def f(x, y):
    return x + y

def g(xs, y):
    out = control_flow.map(f, xs, y)
    return torch.autograd.grad(out, (xs, y), torch.ones_like(out))

gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True))
# gm.print_readable() produces the following:
class g(torch.nn.Module):
    def forward(self, xs_1: f32[3, s1, s2], y_1: f32[s2]):
        # No stacktrace found for following nodes
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.map_impl(body_graph_0, 1, xs_1, y_1);  body_graph_0 = None
        getitem: f32[3, s1, s2] = map_impl[0];  map_impl = None
        ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False)
        is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like);  getitem = None
        body_graph_1 = self.body_graph_1
        map_impl_1 = torch.ops.map_impl(body_graph_1, 2, xs_1, ones_like, y_1);  body_graph_1 = xs_1 = ones_like = None
        getitem_1 = map_impl_1[0]
        getitem_2: f32[3, s1, s2] = map_impl_1[1]
        getitem_3: f32[3, s2] = map_impl_1[2];  map_impl_1 = None
        sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_3, [0], True);  getitem_3 = None
        sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0);  y_1 = None
        view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
        return (getitem_2, view)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s2]):
            # No stacktrace found for following nodes
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg2_1);  arg1_1 = arg2_1 = None
            return [add]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]):
            # No stacktrace found for following nodes
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg3_1);  arg1_1 = None
            is_same_size = torch.ops.aten.is_same_size.default(add, arg2_1);  add = None
            sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg2_1, [0], True)
            sym_size: Sym(s2) = torch.ops.aten.sym_size(arg3_1, 0);  arg3_1 = None
            view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
            return [None, arg2_1, view]
```
### Case 2: list input/output f and autograd
```python
def f(x, y):
    return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]

def g(xs, y):
    out = control_flow.map(f, xs, y)
    flat_out, _ = pytree.tree_flatten(out)
    flat_inp, _ = pytree.tree_flatten((xs, y))
    requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
    return torch.autograd.grad(flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out])

gm = make_fx(g, tracing_mode="symbolic")(
    [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
    torch.ones(5, requires_grad=True))

# gm.print_readable() produces the following:
class g(torch.nn.Module):
    def forward(self, xs, y):
        xs_1: f32[3, s1, s2], xs_2: f32[3, s1, s2], y_1: f32[s2], = fx_pytree.tree_flatten_spec([xs, y], self._in_spec)
        # No stacktrace found for following nodes
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.map_impl(body_graph_0, 2, xs_1, xs_2, y_1);  body_graph_0 = None
        getitem: f32[3, s1, s2] = map_impl[0]
        getitem_1: f32[3, s1, s2] = map_impl[1];  map_impl = None
        ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False)
        ones_like_1: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem_1, pin_memory = False)
        is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like);  getitem = None
        is_same_size_1 = torch.ops.aten.is_same_size.default(getitem_1, ones_like_1);  getitem_1 = None
        body_graph_1 = self.body_graph_1
        map_impl_1 = torch.ops.map_impl(body_graph_1, 4, xs_1, xs_2, ones_like, ones_like_1, y_1);  body_graph_1 = xs_1 = xs_2 = ones_like = ones_like_1 = None
        getitem_2 = map_impl_1[0]
        getitem_3 = map_impl_1[1]
        getitem_4: f32[3, s1, s2] = map_impl_1[2]
        getitem_5: f32[3, s2] = map_impl_1[3];  map_impl_1 = None
        sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_5, [0], True);  getitem_5 = None
        sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0);  y_1 = None
        view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
        return pytree.tree_unflatten([getitem_4, view], self._out_spec)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]):
            # No stacktrace found for following nodes
            cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
            sin: f32[s2] = torch.ops.aten.sin.default(arg3_1)
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1);  arg2_1 = None
            cos_1: f32[s2] = torch.ops.aten.cos.default(arg3_1);  arg3_1 = None
            mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1);  sin_1 = cos_1 = None
            return [add, mul]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s1, s2], arg4_1: f32[s1, s2], arg5_1: f32[s2]):
            # No stacktrace found for following nodes
            cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
            sin: f32[s2] = torch.ops.aten.sin.default(arg5_1)
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1)
            cos_1: f32[s2] = torch.ops.aten.cos.default(arg5_1)
            mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1)
            is_same_size = torch.ops.aten.is_same_size.default(add, arg3_1);  add = None
            is_same_size_1 = torch.ops.aten.is_same_size.default(mul, arg4_1);  mul = None
            mul_1: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, sin_1);  sin_1 = None
            mul_2: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, cos_1);  arg4_1 = cos_1 = None
            sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(mul_1, [0], True);  mul_1 = None
            sym_size: Sym(s2) = torch.ops.aten.sym_size(arg5_1, 0)
            view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = None

            #
            sin_2: f32[s2] = torch.ops.aten.sin.default(arg5_1)
            neg: f32[s2] = torch.ops.aten.neg.default(sin_2);  sin_2 = None
            mul_3: f32[s2] = torch.ops.aten.mul.Tensor(view, neg);  view = neg = None
            cos_2: f32[s1, s2] = torch.ops.aten.cos.default(arg2_1);  arg2_1 = None
            mul_4: f32[s1, s2] = torch.ops.aten.mul.Tensor(mul_2, cos_2);  mul_2 = cos_2 = None
            sum_2: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg3_1, [0], True);  arg3_1 = None
            view_1: f32[s2] = torch.ops.aten.view.default(sum_2, [sym_size]);  sum_2 = sym_size = None
            cos_3: f32[s2] = torch.ops.aten.cos.default(arg5_1);  arg5_1 = None
            mul_5: f32[s2] = torch.ops.aten.mul.Tensor(view_1, cos_3);  view_1 = cos_3 = None
            add_1: f32[s2] = torch.ops.aten.add.Tensor(mul_3, mul_5);  mul_3 = mul_5 = None
            return [None, None, mul_4, add_1]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100494
Approved by: https://github.com/zou3519
Author: ydwu4
Date: 2023-05-16 22:05:11 +00:00
Committed by: PyTorch MergeBot
Parent: 552b712f80
Commit: b8fa41be9d
7 changed files with 617 additions and 95 deletions


@ -137,8 +137,10 @@ def inner(pred, true_fn, false_fn, operands):
mode = _get_current_dispatch_mode()
assert (mode is not None), "Mode should always be enabled for python fallback key"
with _pop_mode_temporarily() as mode:
res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)
return res
if mode.enable_tracing:
return trace_cond(mode, cond, pred, true_fn, false_fn, operands)
else:
return cond(pred, true_fn, false_fn, operands)
@cond.py_impl(FakeTensorMode)
@ -214,13 +216,19 @@ def _has_potential_branch_input_alias(branch, inputs):
def _detect_input_alias(gm):
input_storages = set()
for node in gm.graph.nodes:
if node.op == "placeholder":
# We need to check existence of "val" because we reuse the logic here
# for map operator, where num_mapped_args is a scalar
# and doesn't have a "val" meta.
if node.op == "placeholder" and "val" in node.meta:
input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
if node.op == "output":
for out in node.args:
out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
if out_storage in input_storages:
return True
def check_alias(out):
if out is not None and "val" in out.meta:
out_storage = StorageWeakRef(out.meta['val']._typed_storage())
return out_storage in input_storages
return False
if any(pytree.tree_flatten(pytree.tree_map(check_alias, node.args))[0]):
return True
for _, module in gm.named_children():
if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):


@ -4,8 +4,10 @@ import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
from torch._functorch.aot_autograd import create_joint
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.multiprocessing.reductions import StorageWeakRef
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
@ -17,23 +19,153 @@ from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_pop_mode_temporarily,
)
from torch.utils._pytree import tree_flatten
from torch._dispatch.python import suspend_functionalization
from ._cond import _has_potential_branch_input_alias, _has_potential_branch_input_mutation, UnsupportedAliasMutationException
map = HigherOrderOperator("map")
# TODO: We add this to prevent dynamo from tracing into map_wrapper,
# remove the wrapper call when it's ready.
class MapWrapper(HigherOrderOperator):
def __call__(self, xs, *args):
return map_wrapper(xs, *args)
map = MapWrapper("map")
map_impl = HigherOrderOperator("map_impl")
def create_fw_bw_graph(f, num_mapped_args, *args):
mapped_xs = args[:num_mapped_args]
pos_args = args[num_mapped_args:]
# Note: We create "clean" environments for make_fx by suspending all dispatch keys
# between Autograd and Python key. Currently, we only suspend functionalization but more can be
# added when required. We will encounter two problems if we don't suspend functionalization:
#
# 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
# but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
# However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
# fetch the proxy for the inputs and to fail to capture any operations on them.
#
# 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
# wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
# only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
# when creating the output node, it fails to associate the wrapped tensor with its proxy.
# Instead, it will create _tensor_constant as output.
with suspend_functionalization():
with disable_proxy_modes_tracing():
def from_fun(t):
if isinstance(t, torch.Tensor):
return torch.empty_strided(t.size(), t.stride(), requires_grad=t.requires_grad)
return t
example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]]
example_pos_args = [from_fun(arg) if isinstance(arg, torch.Tensor) else arg for arg in pos_args]
example_flat_out = pytree.tree_map(from_fun, f(*example_xs, *example_pos_args))
if any(not isinstance(out, torch.Tensor) for out in example_flat_out if out is not None):
raise RuntimeError("Expect outputs of map only contains tensors or None. "
f"Got types {[type(out) for out in example_flat_out]}.")
example_grad = [from_fun(out) for out in example_flat_out]
def trace_map(proxy_mode, func_overload, f, xs, *args):
if not isinstance(xs, torch.Tensor):
raise ValueError("map() must loop over a tensor")
if len(xs.shape) == 0 or xs.shape[0] == 0:
raise ValueError("map() cannot be traced with scalar tensors or zero dimension tensors")
if not all(isinstance(o, torch.Tensor) for o in args):
raise ValueError("map() operands must be a list of tensors or modules")
fw_graph = make_fx(f)(*example_xs, *example_pos_args)
def joint_f(*example_args):
joint_mapped_args = example_args[:joint_num_mapped]
args = example_args[joint_num_mapped:]
mapped_input = joint_mapped_args[:num_mapped_args]
mapped_grads = joint_mapped_args[num_mapped_args:]
def fw_with_masks(*args):
fw_out = f(*args)
return fw_out, [True if isinstance(ret, torch.Tensor) and ret.requires_grad else False for ret in fw_out]
joint = create_joint(fw_with_masks)
_, grads = joint(list(mapped_input) + list(args),
[grad for grad in mapped_grads if grad is not None and grad.requires_grad])
# In order to keep map functional for backward graph,
# we clone outputs that are aliasing inputs
input_storage = {StorageWeakRef(arg._typed_storage()) for arg in example_args if isinstance(arg, torch.Tensor)}
def maybe_clone(t):
if isinstance(t, torch.Tensor) and StorageWeakRef(t._typed_storage()) in input_storage:
return t.clone()
return t
return pytree.tree_map(maybe_clone, grads)
joint_num_mapped = len(example_grad) + len(example_xs)
joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
return fw_graph, joint_graph
def map_wrapper(f, xs, *args):
flat_xs, xs_spec = pytree.tree_flatten(xs)
if not all(isinstance(t, torch.Tensor) for t in flat_xs):
raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")
num_mapped_args = len(flat_xs)
shapes = [xs.shape for xs in flat_xs]
leading_dim_size = shapes[0][0]
if leading_dim_size == 0:
raise RuntimeError(
"Leading dimensions of mapped xs cannot be 0.")
if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
raise RuntimeError(
f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}.")
out_spec = None
def flat_fn(*flat_args):
xs = pytree.tree_unflatten(flat_args[:num_mapped_args], xs_spec)
unflattened_out = f(xs, *flat_args[num_mapped_args:])
flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)
nonlocal out_spec
out_spec = tmp_out_spec
return flat_out
return pytree.tree_unflatten(map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec)
class MapAutogradOp(torch.autograd.Function):
@staticmethod
def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
ctx.save_for_backward(*flat_args)
ctx._joint_graph = joint_graph
ctx._num_mapped_args = num_mapped_args
try:
guard = torch._C._AutoDispatchBelowAutograd()
return (*map_impl(fw_graph, num_mapped_args, *flat_args), )
finally:
del guard
@staticmethod
def backward(ctx, *flat_grads):
fw_args = ctx.saved_tensors
fw_mapped_args = fw_args[:ctx._num_mapped_args]
pos_args = fw_args[ctx._num_mapped_args:]
grads = map_impl(ctx._joint_graph, ctx._num_mapped_args + len(flat_grads), *fw_mapped_args, *flat_grads, *pos_args)
return None, None, None, *grads
def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
xs = list(args[:num_mapped])
pos_args = list(args[num_mapped:])
leading_dim_size = xs[0].shape[0]
example_input = _unstack_pytree(xs)[0]
body_graph = f
if not isinstance(body_graph, torch.fx.GraphModule):
body_graph = make_fx(body_graph)(*example_input, *pos_args)
with disable_proxy_modes_tracing():
body_graph = make_fx(f)(xs[0], *args)
example_outs = body_graph(*example_input, *pos_args)
def expand_tensor(t):
if isinstance(t, torch.Tensor):
return t.expand(leading_dim_size, *t.shape)
return t
expanded_outs = pytree.tree_map(expand_tensor, example_outs)
next_name = None
i = 0
@ -45,113 +177,145 @@ def trace_map(proxy_mode, func_overload, f, xs, *args):
next_name = candidate
proxy_mode.tracer.root.register_module(next_name, body_graph)
node_args = (body_graph, xs, *args)
node_args = (body_graph, num_mapped, *args)
proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), node_args)
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
name="map")
outs = [body_graph(x, *args) for x in xs]
# Implementation notes: we need to use new_empty() + copy_() here instead of stack() directly
# because stack([...]) takes a fixed size list which will specialize dynamic shape here.
# Meanwhile we want to preserve the looped over dimension as symbolic shape, such that:
# ys: Tensor[s0, ...] = map(xs: Tensor[s0, ...], *args)
out = outs[0].new_empty([xs.shape[0], *outs[0].shape])
out.copy_(torch.stack(outs))
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
name="map_impl")
return track_tensor_tree(expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer)
def _unstack_pytree(xs):
flat_xs, inspec = pytree.tree_flatten(xs)
if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
raise RuntimeError(f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}")
a = zip(*flat_xs)
pytrees = []
for tuple in a:
pytrees.append(pytree.tree_unflatten(tuple, inspec))
return pytrees
def _stack_pytree(pytrees):
flat_out = []
out_spec = None
for pt in pytrees:
flat_pt, out_spec = pytree.tree_flatten(pt)
flat_out.append(flat_pt)
b = zip(*flat_out)
stacked_out = []
for leaves in b:
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
stacked_out.append(torch.stack(leaves))
elif all(leaf is None for leaf in leaves):
# Backward graph can return None output when forward inputs don't require grad.
# When we eagerly execute backward graph, we need to call _stack_pytree on its output,
# therefore we need to deal with None output.
stacked_out.append(None)
else:
raise RuntimeError(f"Cannot stack {leaves}.")
return pytree.tree_unflatten(stacked_out, out_spec)
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, num_mapped_args, *args):
xs = args[:num_mapped_args]
pos_args = args[num_mapped_args:]
pytrees = []
for inp in _unstack_pytree(xs):
pytrees.append(f(*inp, *pos_args))
return _stack_pytree(pytrees)
@map.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_cpu(f, xs, *args):
mode = _get_current_dispatch_mode()
assert (mode is None), "Mode should never be enabled for CPU/CUDA key"
return torch.stack([f(x, *args) for x in xs])
@map_impl.py_impl(DispatchKey.Autograd)
def map_autograd(f, num_mapped_args, *args):
fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
return flat_out
@map.py_impl(DispatchKey.Autograd)
def map_autograd(f, xs, *args):
# TODO: support autograd
flat_operands, _ = tree_flatten([f, xs, args])
assert all(not f.requires_grad for f in flat_operands
if isinstance(f, torch.Tensor))
_ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
return map(f, xs, *args)
@map.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(f, xs, *args):
@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
mode = _get_current_dispatch_mode()
assert (mode is not None), "Mode should always be enabled for python fallback key"
with _pop_mode_temporarily() as mode:
res = trace_map(mode, map, f, xs, *args)
return res
if mode.enable_tracing:
return trace_map(mode, map_impl, f, num_mapped, *args)
else:
return map_impl(f, num_mapped, *args)
@map.py_impl(FakeTensorMode)
def map_fake_tensor_mode(f, xs, *args):
outs = [f(x, *args) for x in xs]
return outs[0].new_empty([xs.shape[0], *outs[0].shape])
@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(f, num_mapped, *args):
return map_dense(f, num_mapped, *args)
@map.py_impl(DispatchKey.Functionalize)
def map_func(f, xs, *args):
@map_impl.py_impl(DispatchKey.Functionalize)
def map_func(f, num_mapped, *args):
reapply_views = torch._C._functionalization_reapply_views_tls()
xs = args[:num_mapped]
pos_args = args[num_mapped:]
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(args, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views)
mode = 'mutations_and_views' if reapply_views else 'mutations'
guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
try:
functional_map_fn = functionalize(f, remove=mode)
inputs = (unwrapped_xs,) + unwrapped_args
with disable_proxy_modes_tracing():
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
if _has_potential_branch_input_mutation(f, inputs):
if _has_potential_branch_input_mutation(f, example_inputs):
raise UnsupportedAliasMutationException(
"torch.map is mutating the input!"
)
if _has_potential_branch_input_alias(f, inputs):
if _has_potential_branch_input_alias(f, example_inputs):
raise UnsupportedAliasMutationException(
"torch.map is aliasing the input!"
)
map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
return _wrap_all_tensors_to_functional(map_return, level=0)
finally:
del guard
@map.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, xs, *args):
@map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, num_mapped, *args):
"""
Functionalization implementation for torch.map. Currently:
1. We don't allow any input mutation inside the map function
2. Our check for above condition is not exhaustive
"""
xs = args[:num_mapped]
pos_args = args[num_mapped:]
reapply_views = interpreter.functionalize_add_back_views()
mode = 'mutations_and_views' if reapply_views else 'mutations'
# At this point, we will see functionalized tensors, so need to unwrap them first
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(args, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views)
functional_map_fn = functionalize(f, remove=mode)
with interpreter.lower():
inputs = (unwrapped_xs,) + unwrapped_args
if _has_potential_branch_input_mutation(functional_map_fn, inputs):
with disable_proxy_modes_tracing():
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
if _has_potential_branch_input_mutation(f, example_inputs):
raise UnsupportedAliasMutationException(
"torch.map is mutating the input!"
)
if _has_potential_branch_input_alias(functional_map_fn, inputs):
if _has_potential_branch_input_alias(f, example_inputs):
raise UnsupportedAliasMutationException(
"torch.map is aliasing the input!"
)
map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())
# TODO(voz) Make this automatic for keys, this is very ugly atm
map.fallthrough(DispatchKey.PythonDispatcher)
map.fallthrough(DispatchKey.PythonTLSSnapshot)
map.fallthrough(DispatchKey.ADInplaceOrView)
map.fallthrough(DispatchKey.BackendSelect)
map.fallthrough(DispatchKey.AutocastCPU)
map_impl.fallthrough(DispatchKey.PythonDispatcher)
map_impl.fallthrough(DispatchKey.PythonTLSSnapshot)
map_impl.fallthrough(DispatchKey.ADInplaceOrView)
map_impl.fallthrough(DispatchKey.BackendSelect)
map_impl.fallthrough(DispatchKey.AutocastCPU)


@ -27,6 +27,7 @@ from torch.nn.utils.rnn import PackedSequence
from torch.testing._internal.common_device_type import instantiate_device_type_tests, toleranceOverride, tol
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
from functorch import (
grad, vjp, vmap, jacrev,
make_fx
@ -2975,12 +2976,12 @@ def _test_aot_autograd_module_helper(self, device, dtype, training, module_info,
class TestEagerFusionOpInfo(AOTTestCase):
@ops(op_db, allowed_dtypes=(torch.float,))
@ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
@skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', aot_autograd_failures)
def test_aot_autograd_exhaustive(self, device, dtype, op):
_test_aot_autograd_helper(self, device, dtype, op)
@ops(op_db, allowed_dtypes=(torch.float,))
@ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
@patch("functorch.compile.config.debug_assert", True)
@skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive',
aot_autograd_failures | symbolic_aot_autograd_failures)


@ -12,6 +12,15 @@ from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
from torch._dynamo.exc import CondOpArgsMismatchError
def _fake_map(f, x, *args):
from functorch.experimental._map import _stack_pytree, _unstack_pytree
x_pytrees = _unstack_pytree(x)
zs = []
for xp in x_pytrees:
zs.append(f(xp, *args))
return _stack_pytree(zs)
class TestControlFlow(TestCase):
def test_cond_no_trace(self):
def true_fn(x):
@ -34,7 +43,7 @@ class TestControlFlow(TestCase):
x = torch.randn(4, device="cuda")
pred = torch.tensor(False, device="cuda")
result = cond(False, true_fn, false_fn, [x])
result = cond(pred, true_fn, false_fn, [x])
self.assertEqual(result, torch.cos(x))
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@ -45,8 +54,138 @@ class TestControlFlow(TestCase):
xs = torch.ones(3, 2, 2, device="cuda")
y = torch.ones(2, device="cuda")
res = control_flow.map(f, xs, y)
expected = _fake_map(f, xs, y)
self.assertEqual(expected, res)
self.assertEqual(res, control_flow.map(f, torch.ones(3, 2, 2), torch.ones(2)))
def test_map_illegal_inputs(self):
def f(x, y):
return x[0] + x[1] + y
with self.assertRaisesRegex(RuntimeError,
r"Mapped xs can only consist of tensors\. Got xs \[3, tensor\(\[1\., 1\.\]\)\]\."):
_ = control_flow.map(f, (3, torch.ones(2)), torch.ones(2))
with self.assertRaisesRegex(RuntimeError,
r"Leading dimensions of mapped xs cannot be 0\."):
_ = control_flow.map(f, (torch.ones(0, 1, 2), torch.ones(0, 1, 2)), torch.ones(2))
with self.assertRaisesRegex(RuntimeError,
r"Leading dimensions of mapped xs must be consistent\. "
r"Got shapes \[torch\.Size\(\[3, 4, 5\]\), torch\.Size\(\[4, 4, 5\]\)\]\."):
_ = control_flow.map(f, (torch.ones(3, 4, 5), torch.ones(4, 4, 5)), torch.ones(5))
def test_map_illegal_outputs(self):
def f(x, y):
return x.item()
def f1(x, y):
return y.size()
def f2(x, y):
return None
x = torch.ones([3])
y = torch.ones([1, 2, 3])
with self.assertRaisesRegex(RuntimeError, r"Expect outputs of map only contains tensors or None\."):
_ = control_flow.map(f, x, y)
with self.assertRaisesRegex(RuntimeError, r"Expect outputs of map only contains tensors or None\."):
out = control_flow.map(f1, x, y)
# return None is OK
_ = control_flow.map(f2, x, y)
def test_map_list_in_out(self):
def f(x, y):
return [[x[0][0] + y]]
xs = [[torch.ones(3, 2, 2)]]
y = torch.ones(2)
res = control_flow.map(f, xs, y)
expected = _fake_map(f, xs, y)
self.assertEqual(len(res), 1)
self.assertEqual(len(res[0]), 1)
self.assertEqual(expected, res)
def test_map_dict_in_out(self):
def f(x, y):
return {"c": x["a"]["b"] + y}
xs = {"a": {"b": torch.ones(3, 2, 2)}}
y = torch.ones(2)
res = control_flow.map(f, xs, y)
expected = _fake_map(f, xs, y)
self.assertEqual(len(res), 1)
self.assertTrue("c" in res)
self.assertEqual(expected, res)
def test_map_autograd_simple(self):
def f(x, y):
return x.sin().cos() * y.cos().sin()
xs = torch.ones(3, 2, 2, requires_grad=True)
y = torch.ones(2, requires_grad=True)
res = control_flow.map(f, xs, y)
expected_res = _fake_map(f, xs, y)
grad_out = torch.ones_like(res)
grads = torch.autograd.grad(res, (xs, y), grad_out)
expected_grads = torch.autograd.grad(expected_res, (xs, y), grad_out)
self.assertEqual(expected_res, res)
self.assertEqual(expected_grads, grads)
def test_map_autograd_simple_partial_grad(self):
def f(x, y):
return x.sin().cos() * y.cos().sin()
xs = torch.ones(3, 2, 2, requires_grad=True)
# Disable the gradient computation for y
y = torch.ones(2, requires_grad=False)
res = control_flow.map(f, xs, y)
expected_res = _fake_map(f, xs, y)
grad_out = torch.ones_like(res)
grads = torch.autograd.grad(res, (xs,), grad_out)
expected_grads = torch.autograd.grad(expected_res, (xs,), grad_out)
self.assertEqual(expected_res, res)
self.assertEqual(expected_grads, grads)
def test_map_autograd_no_grad_output(self):
def f(x, y):
return x[0].sin().cos() + y, y.cos().sin()
xs = [torch.ones(3, 2, 2, requires_grad=True), torch.ones(3, 3)]
# Disable the gradient computation for y
y = torch.ones(2, requires_grad=False)
res = control_flow.map(f, xs, y)
expected_res = _fake_map(f, xs, y)
grad_out = torch.ones_like(res[0])
grads = torch.autograd.grad(res[0], (xs[0],), grad_out)
expected_grads = torch.autograd.grad(expected_res[0], (xs[0],), grad_out)
self.assertEqual(expected_res, res)
self.assertEqual(expected_grads, grads)
def test_map_autograd_nested_list(self):
import torch.utils._pytree as pytree
def f(x, y):
a, b = x
c, d = a
return [[b.sin() * c.cos()], d.sin() * y.cos()]
def fwbw(map_op, f, x, y):
z = map_op(f, x, y)
flat_x, _ = pytree.tree_flatten(x)
flat_z, _ = pytree.tree_flatten(z)
grads = torch.autograd.grad(flat_z, flat_x, [torch.ones_like(z) for z in flat_z])
return z, grads
x = [[torch.randn(3, 2, 2, requires_grad=True), torch.randn(3, 2, 1, requires_grad=True)],
torch.ones(3, 1, 2, requires_grad=True)]
y = torch.ones(1, requires_grad=True)
true_outs = fwbw(control_flow.map, f, x, y)
fake_outs = fwbw(_fake_map, f, x, y)
self.assertEqual(true_outs, fake_outs)
class TestControlFlowTraced(TestCase):
@ -702,17 +841,15 @@ class TestControlFlowTraced(TestCase):
):
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
def check_map_graph(self, gm, key):
def check_map_count(self, gm, op_count):
i = 0
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.map:
i += 1
self.assertEqual(
node.meta[key].shape[0], node.args[1].meta[key].shape[0]
)
self.assertEqual(i, 1)
for m in gm.modules():
for node in m.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.map_impl:
i += 1
self.assertEqual(i, op_count)
def test_map_real(self):
def test_tracing_map_real(self):
def f(x, y):
return x + y
@ -724,9 +861,9 @@ class TestControlFlowTraced(TestCase):
y = torch.randn(2)
res = gm(x, y)
self.assertEqual(res, g(x, y))
self.check_map_graph(gm, "tensor_meta")
self.check_map_count(gm, 1)
def test_map_symbolic(self):
def test_tracing_map_symbolic_simple(self):
def f(x, y):
return x + y
@ -738,7 +875,140 @@ class TestControlFlowTraced(TestCase):
y = torch.randn(2)
res = gm(x, y)
self.assertEqual(res, g(x, y))
self.check_map_graph(gm, "val")
self.check_map_count(gm, 1)
def test_tracing_map_symbolic_list(self):
def f(x, y):
return [x[0][0] + y, x[1] * y]
def g(xs, y, z):
out = control_flow.map(f, xs, y)
return out[0] + z, out[1] * z
example_x = [[torch.ones(3, 4, 5)], torch.ones(3, 4, 5)]
gm = make_fx(g, tracing_mode="symbolic")(example_x, torch.ones(5), torch.ones(5))
x = [[torch.randn(4, 5, 6)], torch.ones(4, 5, 6)]
y = torch.randn(6)
z = torch.ones(6)
res = gm(x, y, z)
self.assertEqual(res, g(x, y, z))
self.check_map_count(gm, 1)
def test_tracing_map_symbolic_dict(self):
def f(x, y):
return {"d": x["b"]["a"] + y, "e": x["c"] * y}
def g(xs, y, z):
out = control_flow.map(f, xs, y)
return {"f": out["d"] + z, "g": out["e"] * z}
example_x = {"b": {"a": torch.ones(3, 4, 5)}, "c": torch.ones(3, 4, 5)}
gm = make_fx(g, tracing_mode="symbolic")(example_x, torch.ones(5), torch.ones(5))
x = {"b": {"a": torch.randn(4, 5, 6)}, "c": torch.ones(4, 5, 6)}
y = torch.randn(6)
z = torch.ones(6)
res = gm(x, y, z)
self.assertEqual(res, g(x, y, z))
self.check_map_count(gm, 1)
def test_tracing_map_autograd_symbolic_simple(self):
def f(x, y):
return x + y
def g(xs, y):
out = control_flow.map(f, xs, y)
return torch.autograd.grad(out, (xs, y), torch.ones_like(out))
gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True))
x = torch.randn(4, 5, 6, requires_grad=True)
y = torch.randn(6, requires_grad=True)
res = gm(x, y)
self.assertEqual(res, g(x, y))
self.check_map_count(gm, 2)
def test_tracing_map_autograd_symbolic_list(self):
import torch.utils._pytree as pytree
def f(x, y):
return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]
def g(xs, y):
out = control_flow.map(f, xs, y)
flat_out, _ = pytree.tree_flatten(out)
flat_inp, _ = pytree.tree_flatten((xs, y))
requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
return torch.autograd.grad(flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out])
gm = make_fx(g, tracing_mode="symbolic")(
[torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
torch.ones(5, requires_grad=True))
x = [torch.randn(4, 5, 6), torch.ones(4, 5, 6, requires_grad=True)]
y = torch.randn(6, requires_grad=True)
res = gm(x, y)
self.assertEqual(res, g(x, y))
self.check_map_count(gm, 2)
def test_tracing_map_autograd_symbolic_dict(self):
def f(x, y):
return [x["a"] + y, x["b"] * y]
def g(xs, y):
out = control_flow.map(f, xs, y)
flat_out, _ = pytree.tree_flatten(out)
flat_inp, _ = pytree.tree_flatten((xs, y))
requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
return torch.autograd.grad(flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out])
traced_x = {"a": torch.ones(3, 4, 5, requires_grad=True), "b": torch.ones(3, 4, 5, requires_grad=True)}
gm = make_fx(g, tracing_mode="symbolic")(traced_x, torch.ones(5, requires_grad=True))
x = {"a": torch.randn(4, 5, 6, requires_grad=True), "b": torch.ones(4, 5, 6, requires_grad=True)}
y = torch.randn(6, requires_grad=True)
res = gm(x, y)
self.assertEqual(res, g(x, y))
self.check_map_count(gm, 2)
def test_tracing_map_autograd_aot_functionalized(self):
def inner(x, y):
z = x - 1
z.add_(1)
return z * y
def f(xs, y):
res = control_flow.map(inner, xs, y)
grads = torch.autograd.grad(res, (xs, y), torch.ones_like(res))
return grads
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
return pytree.tree_map(from_fun, func(*args, **kwargs))
finally:
torch._disable_functionalization()
return wrapper
example_inputs = (torch.ones(3, 2, 4, requires_grad=True), torch.ones(2, 4, requires_grad=True))
gm = make_fx(f, tracing_mode="symbolic")(*example_inputs)
fgm = make_fx(f_wrapper(f), tracing_mode="symbolic")(*example_inputs)
xs = torch.ones(3, 4, 5, requires_grad=True)
y = torch.ones(4, 5, requires_grad=True)
self.assertEqual(gm(xs, y), f(xs, y))
def count_mutable(gm):
c = 0
for node in gm.graph.nodes:
if node.op == "call_function":
if node.target == torch.ops.map_impl:
c += count_mutable(getattr(gm, str(node.args[0])))
elif schema := getattr(node.target, "_schema", None):
c += int(schema.is_mutable)
return c
self.assertEqual(count_mutable(fgm), 0)
# One for forward, one for recomputation logic in the backward
self.assertEqual(count_mutable(gm), 2)
def test_map_functionalized(self):
def map_fn(x, y):
@ -762,6 +1032,7 @@ class TestControlFlowTraced(TestCase):
for node in gm.body_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
self.check_map_count(gm, 1)
def test_map_functionalized_aot_func(self):
def map_fn(x, y):
@ -848,11 +1119,11 @@ class TestControlFlowTraced(TestCase):
torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
)
pred = torch.tensor(False)
x = torch.randn(3, 2, 2)
y = torch.randn(2)
x = torch.randn(3, 2, 4)
y = torch.randn(4)
res = gm(pred, x, y)
self.assertEqual(res, g(pred, x, y))
self.check_map_graph(gm, "tensor_meta")
self.check_map_count(gm, 1)
def test_nested_map_cond_symbolic(self):
def true_fn(x, y):
@ -875,7 +1146,7 @@ class TestControlFlowTraced(TestCase):
y = torch.randn(2)
res = gm(pred, x, y)
self.assertEqual(res, g(pred, x, y))
self.check_map_graph(gm, "val")
self.check_map_count(gm, 1)
def test_nested_cond_map_cond_symbolic(self):
@ -909,6 +1180,7 @@ class TestControlFlowTraced(TestCase):
y = torch.randn(2)
res = gm(p, pred, xs, y)
self.assertEqual(res, main(p, pred, xs, y))
self.check_map_count(gm, 2)
def test_cond_with_sym_pred(self):
def true_fn(x):


@ -5,6 +5,7 @@ import torch
from torch.testing._internal.common_utils import TestGradients, run_tests
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, OpDTypes)
@ -17,7 +18,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
class TestBwdGradients(TestGradients):
# Tests that gradients are computed correctly
@_gradcheck_ops(op_db)
@_gradcheck_ops(op_db + control_flow_opinfo_db)
def test_fn_grad(self, device, dtype, op):
# This is verified by test_dtypes in test_ops.py
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
@ -51,7 +52,7 @@ class TestBwdGradients(TestGradients):
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
# Test that gradients of gradients are computed correctly
@_gradcheck_ops(op_db)
@_gradcheck_ops(op_db + control_flow_opinfo_db)
def test_fn_gradgrad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if not op.supports_gradgrad:


@ -16,6 +16,7 @@ from torch.fx.experimental.symbolic_shapes import (
constrain_range, guard_int, GuardOnDataDependentSymNode
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
from torch.testing._internal.common_device_type import ops
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
@ -1610,17 +1611,17 @@ def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False):
self.assertEqual(new_out, old_out)
class TestProxyTensorOpInfo(TestCase):
@ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
@ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
def test_make_fx_exhaustive(self, device, dtype, op):
_test_make_fx_helper(self, device, dtype, op, "real")
@ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
@ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures))
def test_make_fx_fake_exhaustive(self, device, dtype, op):
_test_make_fx_helper(self, device, dtype, op, "fake")
@ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
@ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures)
def test_make_fx_symbolic_exhaustive(self, device, dtype, op):


@ -0,0 +1,75 @@
import torch
import functools
from torch.testing import make_tensor
from functorch.experimental.control_flow import map
from torch.testing._internal.opinfo.core import (
OpInfo,
SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput([make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)))
def inner_f(x, y0, y1):
return [x[0].cos().add_(1.) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]
def simple_map(xs, y0, y1):
def f(x, y0, y1):
return inner_f(x, y0, y1)
return map(f, xs, y0, y1)
def nested_map(xs, y0, y1):
def f1(xx, y0, y1):
def f2(x, y0, y1):
return inner_f(x, y0, y1)
return map(f2, xx, y0, y1)
return map(f1, xs, y0, y1)
def triple_nested_map(xs, y0, y1):
def f0(xs, y0, y1):
def f1(xx, y0, y1):
def f2(x, y0, y1):
return inner_f(x, y0, y1)
return map(f2, xx, y0, y1)
return map(f1, xs, y0, y1)
return map(f0, xs, y0, y1)
control_flow_opinfo_db = [
OpInfo(
"MapControlflowOp",
op=simple_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
),
OpInfo(
"NestedMapControlflowOp",
op=nested_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
),
OpInfo(
"TripleNestedMapControlflowOp",
op=triple_nested_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
)
]