Support map autograd and pytree in/out (#100494)

This PR adds autograd and pytree support for the map operator.

Implementation-wise:

1. We temporarily introduce two HigherOrderOperators, "map" and "map_impl":
- "map" is user-facing. It unwraps the pytrees in the inputs and creates a flat_fn over the flattened arguments. Dynamo currently cannot deal with pytree.tree_flatten and pytree.tree_unflatten, so we make "map" a HigherOrderOperator to trigger Dynamo's logic for handling HigherOrderOperators.
- "map_impl" is the actual operator that works with the rest of the torch subsystems, such as functionalization and make_fx. It accepts flattened arguments and a num_mapped_args integer denoting how many of the flattened arguments need to be mapped, i.e. have their first dimension unstacked. (A plain-Python sketch of this split follows the list.)

2. In the Autograd dispatch key, we create the forward and backward graphs and invoke them through a torch.autograd.Function (also sketched below). The backward graph is currently recomputation-based; in the future we should partition the joint graph to make it more efficient.
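
For intuition, the eager semantics that "map" and "map_impl" together implement can be written as a plain-Python sketch. This is illustrative only: `reference_map` is a hypothetical name, and the real operators dispatch as HigherOrderOperators rather than running a Python loop.

```python
import torch
import torch.utils._pytree as pytree

def reference_map(f, xs, *pos_args):
    # "map" side: flatten the pytree inputs and build a flat_fn that
    # works purely on flat tensor arguments.
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    num_mapped_args = len(flat_xs)
    out_spec = []  # filled in by flat_fn on the first call

    def flat_fn(*flat_args):
        unflat = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec)
        flat_out, spec = pytree.tree_flatten(f(unflat, *flat_args[num_mapped_args:]))
        out_spec[:] = [spec]
        return flat_out

    # "map_impl" side: unstack the first num_mapped_args tensors along
    # dim 0, apply flat_fn to each slice, and restack each output.
    per_slice = [
        flat_fn(*slices, *pos_args)
        for slices in zip(*(t.unbind(0) for t in flat_xs))
    ]
    stacked = [torch.stack(outs) for outs in zip(*per_slice)]
    return pytree.tree_unflatten(stacked, out_spec[0])
```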
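A condensed sketch of the autograd plumbing from point 2 follows. This is hypothetical: `fw_graph` and `joint_graph` stand in for the graphs created in the Autograd key, `map_impl` for the operator call; gradient reduction for non-mapped args and all error checking are omitted.

```python
import torch

class MapAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
        ctx._joint_graph = joint_graph
        ctx._num_mapped_args = num_mapped_args
        ctx.save_for_backward(*flat_args)
        with torch._C._AutoDispatchBelowAutograd():
            return tuple(map_impl(fw_graph, num_mapped_args, *flat_args))

    @staticmethod
    def backward(ctx, *flat_grads):
        fw_args = ctx.saved_tensors
        fw_mapped_args = fw_args[: ctx._num_mapped_args]
        pos_args = fw_args[ctx._num_mapped_args :]
        # Recomputation-based backward: the joint graph re-runs the forward
        # and then computes grads, mapping over both the original mapped
        # inputs and the incoming gradients (hence the larger num_mapped_args
        # in the second map_impl call of the traces below).
        grads = map_impl(
            ctx._joint_graph,
            ctx._num_mapped_args + len(flat_grads),
            *fw_mapped_args, *flat_grads, *pos_args,
        )
        return None, None, None, *grads
```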

Example traced graphs for the map operator:
### Case 1: simple f and autograd
```python
def f(x, y):
    return x + y

def g(xs, y):
    out = control_flow.map(f, xs, y)
    return torch.autograd.grad(out, (xs, y), torch.ones_like(out))

gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True))
# gm.print_readable() produces the following:
class g(torch.nn.Module):
    def forward(self, xs_1: f32[3, s1, s2], y_1: f32[s2]):
        # No stacktrace found for following nodes
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.map_impl(body_graph_0, 1, xs_1, y_1);  body_graph_0 = None
        getitem: f32[3, s1, s2] = map_impl[0];  map_impl = None
        ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False)
        is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like);  getitem = None
        body_graph_1 = self.body_graph_1
        map_impl_1 = torch.ops.map_impl(body_graph_1, 2, xs_1, ones_like, y_1);  body_graph_1 = xs_1 = ones_like = None
        getitem_1 = map_impl_1[0]
        getitem_2: f32[3, s1, s2] = map_impl_1[1]
        getitem_3: f32[3, s2] = map_impl_1[2];  map_impl_1 = None
        sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_3, [0], True);  getitem_3 = None
        sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0);  y_1 = None
        view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
        return (getitem_2, view)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s2]):
            # No stacktrace found for following nodes
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg2_1);  arg1_1 = arg2_1 = None
            return [add]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]):
            # No stacktrace found for following nodes
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg3_1);  arg1_1 = None
            is_same_size = torch.ops.aten.is_same_size.default(add, arg2_1);  add = None
            sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg2_1, [0], True)
            sym_size: Sym(s2) = torch.ops.aten.sym_size(arg3_1, 0);  arg3_1 = None
            view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
            return [None, arg2_1, view]
```
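Reading the trace: in `torch.ops.map_impl(body_graph_0, 1, xs_1, y_1)`, the `1` is num_mapped_args, so only `xs_1` is unstacked along its first dimension while `y_1` is passed to every iteration unchanged. The second call uses num_mapped_args = 2 because the backward maps over both `xs_1` and the incoming gradient `ones_like`. A quick sanity check that the traced module matches eager execution (not part of the PR; reuses `g` and `gm` from above):

```python
xs = torch.randn(3, 4, 5, requires_grad=True)
y = torch.randn(5, requires_grad=True)
# Both calls return (grad_xs, grad_y); the traced graphs should agree with eager.
for eager, traced in zip(g(xs, y), gm(xs, y)):
    torch.testing.assert_close(eager, traced)
```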
### Case 2: list input/output f and autograd
```python
def f(x, y):
    return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]

def g(xs, y):
    out = control_flow.map(f, xs, y)
    flat_out, _ = pytree.tree_flatten(out)
    flat_inp, _ = pytree.tree_flatten((xs, y))
    requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
    return torch.autograd.grad(flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out])

gm = make_fx(g, tracing_mode="symbolic")(
    [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
    torch.ones(5, requires_grad=True))

# gm.print_readable() produces the following:
class g(torch.nn.Module):
    def forward(self, xs, y):
        xs_1: f32[3, s1, s2], xs_2: f32[3, s1, s2], y_1: f32[s2], = fx_pytree.tree_flatten_spec([xs, y], self._in_spec)
        # No stacktrace found for following nodes
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.map_impl(body_graph_0, 2, xs_1, xs_2, y_1);  body_graph_0 = None
        getitem: f32[3, s1, s2] = map_impl[0]
        getitem_1: f32[3, s1, s2] = map_impl[1];  map_impl = None
        ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False)
        ones_like_1: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem_1, pin_memory = False)
        is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like);  getitem = None
        is_same_size_1 = torch.ops.aten.is_same_size.default(getitem_1, ones_like_1);  getitem_1 = None
        body_graph_1 = self.body_graph_1
        map_impl_1 = torch.ops.map_impl(body_graph_1, 4, xs_1, xs_2, ones_like, ones_like_1, y_1);  body_graph_1 = xs_1 = xs_2 = ones_like = ones_like_1 = None
        getitem_2 = map_impl_1[0]
        getitem_3 = map_impl_1[1]
        getitem_4: f32[3, s1, s2] = map_impl_1[2]
        getitem_5: f32[3, s2] = map_impl_1[3];  map_impl_1 = None
        sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_5, [0], True);  getitem_5 = None
        sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0);  y_1 = None
        view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
        return pytree.tree_unflatten([getitem_4, view], self._out_spec)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]):
            # No stacktrace found for following nodes
            cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
            sin: f32[s2] = torch.ops.aten.sin.default(arg3_1)
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1);  arg2_1 = None
            cos_1: f32[s2] = torch.ops.aten.cos.default(arg3_1);  arg3_1 = None
            mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1);  sin_1 = cos_1 = None
            return [add, mul]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s1, s2], arg4_1: f32[s1, s2], arg5_1: f32[s2]):
            # No stacktrace found for following nodes
            cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
            sin: f32[s2] = torch.ops.aten.sin.default(arg5_1)
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1)
            cos_1: f32[s2] = torch.ops.aten.cos.default(arg5_1)
            mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1)
            is_same_size = torch.ops.aten.is_same_size.default(add, arg3_1);  add = None
            is_same_size_1 = torch.ops.aten.is_same_size.default(mul, arg4_1);  mul = None
            mul_1: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, sin_1);  sin_1 = None
            mul_2: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, cos_1);  arg4_1 = cos_1 = None
            sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(mul_1, [0], True);  mul_1 = None
            sym_size: Sym(s2) = torch.ops.aten.sym_size(arg5_1, 0)
            view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = None

            #
            sin_2: f32[s2] = torch.ops.aten.sin.default(arg5_1)
            neg: f32[s2] = torch.ops.aten.neg.default(sin_2);  sin_2 = None
            mul_3: f32[s2] = torch.ops.aten.mul.Tensor(view, neg);  view = neg = None
            cos_2: f32[s1, s2] = torch.ops.aten.cos.default(arg2_1);  arg2_1 = None
            mul_4: f32[s1, s2] = torch.ops.aten.mul.Tensor(mul_2, cos_2);  mul_2 = cos_2 = None
            sum_2: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg3_1, [0], True);  arg3_1 = None
            view_1: f32[s2] = torch.ops.aten.view.default(sum_2, [sym_size]);  sum_2 = sym_size = None
            cos_3: f32[s2] = torch.ops.aten.cos.default(arg5_1);  arg5_1 = None
            mul_5: f32[s2] = torch.ops.aten.mul.Tensor(view_1, cos_3);  view_1 = cos_3 = None
            add_1: f32[s2] = torch.ops.aten.add.Tensor(mul_3, mul_5);  mul_3 = mul_5 = None
            return [None, None, mul_4, add_1]
```
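Note how `body_graph_1` above recomputes the forward (cos, sin, add, sin_1, cos_1, mul) before producing any gradients: this is the recomputation-based backward from point 2, and it is exactly the redundancy that partitioning the joint graph would remove.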

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100494
Approved by: https://github.com/zou3519
Excerpt of the test changes:
```diff
@@ -5,6 +5,7 @@ import torch
 from torch.testing._internal.common_utils import TestGradients, run_tests
 from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops, OpDTypes)
@@ -17,7 +18,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
 class TestBwdGradients(TestGradients):
     # Tests that gradients are computed correctly
-    @_gradcheck_ops(op_db)
+    @_gradcheck_ops(op_db + control_flow_opinfo_db)
     def test_fn_grad(self, device, dtype, op):
         # This is verified by test_dtypes in test_ops.py
         if dtype not in op.supported_backward_dtypes(torch.device(device).type):
@@ -51,7 +52,7 @@ class TestBwdGradients(TestGradients):
         self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
     # Test that gradients of gradients are computed correctly
-    @_gradcheck_ops(op_db)
+    @_gradcheck_ops(op_db + control_flow_opinfo_db)
     def test_fn_gradgrad(self, device, dtype, op):
         self._skip_helper(op, device, dtype)
         if not op.supports_gradgrad:
```
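
In standalone form, what the updated tests verify is roughly the following (an illustrative sketch, not code from the PR; gradcheck needs double precision, and `control_flow` is assumed to come from functorch.experimental as in the examples above):

```python
import torch
from torch.autograd import gradcheck, gradgradcheck
from functorch.experimental import control_flow

def f(x, y):
    return x.sin() * y.cos()

xs = torch.randn(3, 4, 5, dtype=torch.double, requires_grad=True)
y = torch.randn(5, dtype=torch.double, requires_grad=True)

# Compare map's analytical gradients (the traced backward graphs above)
# against numerical gradients, including gradients of gradients.
gradcheck(lambda xs, y: control_flow.map(f, xs, y), (xs, y))
gradgradcheck(lambda xs, y: control_flow.map(f, xs, y), (xs, y))
```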