mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support map autograd and pytree in/out. (#101633)
Rebased https://github.com/pytorch/pytorch/pull/100494 and added dummy AOTConfig. This PR adds autograd and pytree support for map operator. Implementation-wise: 1. We temporarily make two HigherOrderOperators, "map" and "map_impl": - "map" is user-facing. Currently, it unwraps the pytrees in inputs and create a flat_fn for it. Dynamo currently cannot deal with pytree.tree_flatten and pytree.tree_unflatten, we therefore make it a HigherOrderOperator to trigger dynamo logic of handling HigherOrderOperators. - "map_impl" is the actual operator that works with the rest of torch subsystems such as functionalization, make_fx. It accepts flattend arguments, and a num_mapped_args integer denoting how many of the flattend arguments need to mapped i.e. their first dimension will be unstacked. 2. We create the forward and backward graph in autograd key and call torch.autograd.Function. Currently, the backward graph is recomputation-based and we need to partition the joint graph in the future to be more efficient. Example traced graphs for map operators: ### Case 1: simple f and autograd ```python def f(x, y): return x + y def g(xs, y): out = control_flow.map(f, xs, y) return torch.autograd.grad(out, (xs, y), torch.ones_like(out)) gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True)) # gm.print_readable() produces following: class g(torch.nn.Module): def forward(self, xs_1: f32[3, s1, s2], y_1: f32[s2]): # No stacktrace found for following nodes body_graph_0 = self.body_graph_0 map_impl = torch.ops.map_impl(body_graph_0, 1, xs_1, y_1); body_graph_0 = None getitem: f32[3, s1, s2] = map_impl[0]; map_impl = None ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False) is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like); getitem = None body_graph_1 = self.body_graph_1 map_impl_1 = torch.ops.map_impl(body_graph_1, 2, xs_1, ones_like, y_1); body_graph_1 = xs_1 = ones_like = None getitem_1 = map_impl_1[0] getitem_2: f32[3, s1, s2] = map_impl_1[1] getitem_3: f32[3, s2] = map_impl_1[2]; map_impl_1 = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_3, [0], True); getitem_3 = None sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0); y_1 = None view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None return (getitem_2, view) class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s2]): # No stacktrace found for following nodes add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = None return [add] class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]): # No stacktrace found for following nodes add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg3_1); arg1_1 = None is_same_size = torch.ops.aten.is_same_size.default(add, arg2_1); add = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg2_1, [0], True) sym_size: Sym(s2) = torch.ops.aten.sym_size(arg3_1, 0); arg3_1 = None view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None return [None, arg2_1, view] ``` ### Case 2: list input/output f and autograd ```python def f(x, y): return [x[0].cos() + y.sin(), x[1].sin() * y.cos()] def g(xs, y): out = control_flow.map(f, xs, y) flat_out, _ = pytree.tree_flatten(out) flat_inp, _ = pytree.tree_flatten((xs, y)) requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad] return torch.autograd.grad(flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]) gm = make_fx(g, tracing_mode="symbolic")( [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)], torch.ones(5, requires_grad=True)) # gm.print_readable() produces following: class g(torch.nn.Module): def forward(self, xs, y): xs_1: f32[3, s1, s2], xs_2: f32[3, s1, s2], y_1: f32[s2], = fx_pytree.tree_flatten_spec([xs, y], self._in_spec) # No stacktrace found for following nodes body_graph_0 = self.body_graph_0 map_impl = torch.ops.map_impl(body_graph_0, 2, xs_1, xs_2, y_1); body_graph_0 = None getitem: f32[3, s1, s2] = map_impl[0] getitem_1: f32[3, s1, s2] = map_impl[1]; map_impl = None ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False) ones_like_1: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem_1, pin_memory = False) is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like); getitem = None is_same_size_1 = torch.ops.aten.is_same_size.default(getitem_1, ones_like_1); getitem_1 = None body_graph_1 = self.body_graph_1 map_impl_1 = torch.ops.map_impl(body_graph_1, 4, xs_1, xs_2, ones_like, ones_like_1, y_1); body_graph_1 = xs_1 = xs_2 = ones_like = ones_like_1 = None getitem_2 = map_impl_1[0] getitem_3 = map_impl_1[1] getitem_4: f32[3, s1, s2] = map_impl_1[2] getitem_5: f32[3, s2] = map_impl_1[3]; map_impl_1 = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_5, [0], True); getitem_5 = None sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0); y_1 = None view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None return pytree.tree_unflatten([getitem_4, view], self._out_spec) class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]): # No stacktrace found for following nodes cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1); arg1_1 = None sin: f32[s2] = torch.ops.aten.sin.default(arg3_1) add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1); arg2_1 = None cos_1: f32[s2] = torch.ops.aten.cos.default(arg3_1); arg3_1 = None mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1); sin_1 = cos_1 = None return [add, mul] class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s1, s2], arg4_1: f32[s1, s2], arg5_1: f32[s2]): # No stacktrace found for following nodes cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1); arg1_1 = None sin: f32[s2] = torch.ops.aten.sin.default(arg5_1) add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1) cos_1: f32[s2] = torch.ops.aten.cos.default(arg5_1) mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1) is_same_size = torch.ops.aten.is_same_size.default(add, arg3_1); add = None is_same_size_1 = torch.ops.aten.is_same_size.default(mul, arg4_1); mul = None mul_1: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, sin_1); sin_1 = None mul_2: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, cos_1); arg4_1 = cos_1 = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(mul_1, [0], True); mul_1 = None sym_size: Sym(s2) = torch.ops.aten.sym_size(arg5_1, 0) view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = None # sin_2: f32[s2] = torch.ops.aten.sin.default(arg5_1) neg: f32[s2] = torch.ops.aten.neg.default(sin_2); sin_2 = None mul_3: f32[s2] = torch.ops.aten.mul.Tensor(view, neg); view = neg = None cos_2: f32[s1, s2] = torch.ops.aten.cos.default(arg2_1); arg2_1 = None mul_4: f32[s1, s2] = torch.ops.aten.mul.Tensor(mul_2, cos_2); mul_2 = cos_2 = None sum_2: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg3_1, [0], True); arg3_1 = None view_1: f32[s2] = torch.ops.aten.view.default(sum_2, [sym_size]); sum_2 = sym_size = None cos_3: f32[s2] = torch.ops.aten.cos.default(arg5_1); arg5_1 = None mul_5: f32[s2] = torch.ops.aten.mul.Tensor(view_1, cos_3); view_1 = cos_3 = None add_1: f32[s2] = torch.ops.aten.add.Tensor(mul_3, mul_5); mul_3 = mul_5 = None return [None, None, mul_4, add_1] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/101633 Approved by: https://github.com/zou3519
This commit is contained in:
@ -5,6 +5,7 @@ import torch
|
||||
|
||||
from torch.testing._internal.common_utils import TestGradients, run_tests
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, OpDTypes)
|
||||
|
||||
@ -17,7 +18,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
|
||||
class TestBwdGradients(TestGradients):
|
||||
# Tests that gradients are computed correctly
|
||||
@_gradcheck_ops(op_db)
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db)
|
||||
def test_fn_grad(self, device, dtype, op):
|
||||
# This is verified by test_dtypes in test_ops.py
|
||||
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
|
||||
@ -51,7 +52,7 @@ class TestBwdGradients(TestGradients):
|
||||
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
||||
|
||||
# Test that gradients of gradients are computed correctly
|
||||
@_gradcheck_ops(op_db)
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db)
|
||||
def test_fn_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.supports_gradgrad:
|
||||
|
Reference in New Issue
Block a user