Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert " [while_loop][autograd] support autograd_key of while_loop (#160483)"
This reverts commit 2b8a83901c58a0858ea9e4ce00055f48e6ed164c. Reverted https://github.com/pytorch/pytorch/pull/160483 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but some trunk tests are failing either from this PR or the previous one in the stack ([comment](https://github.com/pytorch/pytorch/pull/160483#issuecomment-3263597325))
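For context: PR #160483 made torch._higher_order_ops.while_loop differentiable under torch.compile, and that is the behavior this revert removes. A minimal sketch of it, condensed from test_while_loop_autograd_simple in the diff below (module, tensor, and function names follow that test, which additionally runs under torch._dynamo.config.patch(capture_scalar_outputs=True); nothing here is new API):

    import torch

    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            def cond_fn(x):
                return x.sum() < 2

            def body_fn(x):
                return x * x + 1 + self.linear(x)

            return torch._higher_order_ops.while_loop(cond_fn, body_fn, (x,))

    x = torch.randn(3, 3, requires_grad=True)
    out = torch.compile(Mod(), fullgraph=True)(x)
    # With the reverted autograd support in place, this traces a reverse while_loop.
    out.sum().backward()

After the revert, while_loop is registered with autograd_not_implemented again (see the while_loop_op.py_autograd_impl hunk below), so the backward call above is no longer supported.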
@@ -394,14 +394,14 @@ def _while_loop_tests():
([torch.randn(3, 3)], {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}),
),
),
"int_carry": (int_carry, (torch.randn(2, 3),)),
"int_carry": (int_carry, (torch.randn(2, 3, requires_grad=True),)),
"pytree_int_carry": (
pytree_int_carry,
(torch.randn(2, 3),),
(torch.randn(2, 3, requires_grad=True),),
),
"const_and_symint_output": (
const_and_symint_output,
(torch.randn(2, 3),),
(torch.randn(2, 3, requires_grad=True),),
),
}

@@ -5513,35 +5513,69 @@ def forward(self, arg0_1):
gm = backend.graphs[0]
if torch._dynamo.config.inline_inbuilt_nn_modules:
self.assertExpectedInline(
normalize_gm(gm.print_readable(print_output=False)),
gm.code.strip(),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_iter_: "i64[]", L_x_: "f32[2, 2]", L_self_buffers_dec_: "i64[]", L_self_modules_linear_parameters_weight_: "f32[2, 2]", L_self_modules_linear_parameters_bias_: "f32[2]"):
l_iter_ = L_iter_
l_x_ = L_x_
l_self_buffers_dec_ = L_self_buffers_dec_
l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_
l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_

cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None
getitem: "i64[]" = while_loop[0]
getitem_1: "f32[2, 2]" = while_loop[1]; while_loop = None
return (getitem, getitem_1)

class cond_fn_0(torch.nn.Module):
def forward(self, child: "i64[]", child_1: "f32[2, 2]", l_self_buffers_dec__cond_fn: "i64[]", l_self_modules_linear_parameters_bias__body_fn: "f32[2]", l_self_modules_linear_parameters_weight__body_fn: "f32[2, 2]"):
sub: "i64[]" = child - l_self_buffers_dec__cond_fn; child = l_self_buffers_dec__cond_fn = None
gt: "b8[]" = sub > 0; sub = None
return gt

class body_fn_0(torch.nn.Module):
def forward(self, child_2: "i64[]", child_3: "f32[2, 2]", l_self_buffers_dec__cond_fn: "i64[]", l_self_modules_linear_parameters_bias__body_fn: "f32[2]", l_self_modules_linear_parameters_weight__body_fn: "f32[2, 2]"):
child: "i64[]" = child_2 - 1; child_2 = None
child_4: "f32[2, 2]" = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None
return (child, child_4)
""", # noqa: B950
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_dec_ : torch.Tensor, L_self_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter, L_self_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter):
l_iter_ = L_iter_
l_x_ = L_x_
l_self_buffers_dec_ = L_self_buffers_dec_
l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_
l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None
getitem = while_loop[0]
getitem_1 = while_loop[1]; while_loop = None
return (getitem, getitem_1)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_fn_0.code.strip(),
"""\
def forward(self, child : torch.Tensor, child_1 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
sub = child - l_self_buffers_dec__cond_fn; child = l_self_buffers_dec__cond_fn = None
gt = sub > 0; sub = None
return gt""", # noqa: B950
)
self.assertExpectedInline(
gm.body_fn_0.code.strip(),
"""\
def forward(self, child_2 : torch.Tensor, child_3 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
child = child_2 - 1; child_2 = None
child_4 = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None
return (child, child_4)""", # noqa: B950
)
else:
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor):
l_iter_ = L_iter_
l_x_ = L_x_
l__self___dec = self.L__self___dec
l__self___linear_weight = self.L__self___linear_weight
l__self___linear_bias = self.L__self___linear_bias
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l__self___dec, l__self___linear_bias, l__self___linear_weight)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l__self___dec = l__self___linear_bias = l__self___linear_weight = None
getitem = while_loop[0]
getitem_1 = while_loop[1]; while_loop = None
return (getitem, getitem_1)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_fn_0.code.strip(),
"""\
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None
gt = sub > 0; sub = None
return gt""", # noqa: B950
)
self.assertExpectedInline(
gm.body_fn_0.code.strip(),
"""\
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
child = l_iter_ - 1; l_iter_ = None
child_1 = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None
return (child, child_1)""", # noqa: B950
)

def test_while_loop_nested2_traced(self):
@@ -8077,7 +8111,7 @@ class GraphModule(torch.nn.Module):
m, args = WHILE_LOOP_TESTS["pytree_int_carry"]
dynamic_shapes = {"x": {0: torch.export.Dim("dim_x")}} if dynamic else None
ep = self._check_export(m, args, strict=strict, dynamic_shapes=dynamic_shapes)
if strict and dynamic and not TEST_WITH_CROSSREF:
if strict and dynamic:
self.assertExpectedInline(
normalize_gm(ep.module().print_readable(print_output=False)),
"""\
@@ -8235,154 +8269,6 @@ class GraphModule(torch.nn.Module):
self.assertEqual(compiled_out[1].size(0), 3)
self.assertEqual(compiled_out, mod(x))

@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_while_loop_autograd_simple(self):
backend = torch._dynamo.testing.AotEagerAndRecordGraphs()

class ModEager(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
while x.sum() < 2:
x = x * x + 1 + self.linear(x)
return x

class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
def cond_fn(x):
return x.sum() < 2

def body_fn(x):
return x * x + 1 + self.linear(x)

return torch._higher_order_ops.while_loop(cond_fn, body_fn, (x,))

x = torch.randn(3, 3, requires_grad=True)
x_clone = x.clone()
mod = Mod()
mod_eager = ModEager()
# Copy weights from mod to mod_eager
mod_eager.load_state_dict(mod.state_dict())
compiled_out = torch.compile(mod, backend=backend, fullgraph=True)(x)
exp_out = mod_eager(x_clone)
compiled_out.sum().backward()
exp_out.sum().backward()
self.assertEqual(compiled_out, exp_out)
eager_parameters = dict(mod_eager.named_parameters())
compiled_parameters = dict(mod.named_parameters())
for name, param in compiled_parameters.items():
self.assertEqual(param, eager_parameters[name])
self.assertEqual(param.grad, eager_parameters[name].grad)

self.assertEqual(
len(
backend.fw_graphs[0].graph.find_nodes(
op="call_function",
target=torch.ops.higher_order.while_loop_stack_output,
)
),
1,
)
self.assertEqual(
len(
backend.bw_graphs[0].graph.find_nodes(
op="call_function", target=torch.ops.higher_order.while_loop
)
),
1,
)
if not TEST_WITH_CROSSREF:
self.assertExpectedInline(
normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]", primals_3: "f32[3]"):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop_stack_output = torch.ops.higher_order.while_loop_stack_output(while_loop_cond_graph_0, while_loop_body_graph_0, (primals_1,), (primals_3, primals_2)); while_loop_cond_graph_0 = while_loop_body_graph_0 = None
getitem: "f32[u2, 3, 3]" = while_loop_stack_output[0]; while_loop_stack_output = None
select: "f32[3, 3]" = torch.ops.aten.select.int(getitem, 0, -1)
unsqueeze: "f32[1, 3, 3]" = torch.ops.aten.unsqueeze.default(primals_1, 0); primals_1 = None
slice_1: "f32[u2 - 1, 3, 3]" = torch.ops.aten.slice.Tensor(getitem, 0, 0, -1); getitem = None
cat: "f32[u2, 3, 3]" = torch.ops.aten.cat.default([unsqueeze, slice_1]); unsqueeze = slice_1 = None
return (select, primals_2, primals_3, cat)

class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3, 3]"):
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 2); sum_1 = None
return lt

class while_loop_body_graph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3, 3]"):
mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1)
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(mul, 1); mul = None
t: "f32[3, 3]" = torch.ops.aten.t.default(arg2_1); arg2_1 = None
addmm: "f32[3, 3]" = torch.ops.aten.addmm.default(arg1_1, arg0_1, t); arg1_1 = arg0_1 = t = None
add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, addmm); add = addmm = None
return (add_1,)
""", # noqa: B950
)

self.assertExpectedInline(
normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_2: "f32[3, 3]", primals_3: "f32[3]", cat: "f32[u2, 3, 3]", tangents_1: "f32[3, 3]"):
zeros: "i64[]" = torch.ops.aten.zeros.default([], dtype = torch.int64, device = device(type='cpu'), pin_memory = False)
zeros_like: "f32[3]" = torch.ops.aten.zeros_like.default(primals_3, pin_memory = False)
zeros_like_1: "f32[3, 3]" = torch.ops.aten.zeros_like.default(primals_2, pin_memory = False)
while_loop_cond_graph_1 = self.while_loop_cond_graph_1
while_loop_body_graph_1 = self.while_loop_body_graph_1
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_1, while_loop_body_graph_1, (zeros, tangents_1, zeros_like, zeros_like_1), (cat, primals_3, primals_2)); while_loop_cond_graph_1 = while_loop_body_graph_1 = zeros = tangents_1 = zeros_like = zeros_like_1 = cat = primals_3 = primals_2 = None
getitem_2: "f32[3, 3]" = while_loop[1]
getitem_3: "f32[3]" = while_loop[2]
getitem_4: "f32[3, 3]" = while_loop[3]; while_loop = None
return (getitem_2, getitem_4, getitem_3)

class while_loop_cond_graph_1(torch.nn.Module):
def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 3]", arg2_1: "f32[3]", arg3_1: "f32[3, 3]", arg4_1: "f32[u2, 3, 3]", arg5_1: "f32[3]", arg6_1: "f32[3, 3]"):
sym_size_int_1: "Sym(u2)" = torch.ops.aten.sym_size.int(arg4_1, 0); arg4_1 = None

lt: "b8[]" = torch.ops.aten.lt.Scalar(arg0_1, sym_size_int_1); arg0_1 = sym_size_int_1 = None
return lt

class while_loop_body_graph_1(torch.nn.Module):
def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 3]", arg2_1: "f32[3]", arg3_1: "f32[3, 3]", arg4_1: "f32[u2, 3, 3]", arg5_1: "f32[3]", arg6_1: "f32[3, 3]"):
sym_size_int_1: "Sym(u2)" = torch.ops.aten.sym_size.int(arg4_1, 0)

rsub: "i64[]" = torch.ops.aten.rsub.Scalar(arg0_1, sym_size_int_1); sym_size_int_1 = None
sub_1: "i64[]" = torch.ops.aten.sub.Tensor(rsub, 1); rsub = None
_local_scalar_dense: "Sym(u9)" = torch.ops.aten._local_scalar_dense.default(sub_1); sub_1 = None
select: "f32[3, 3]" = torch.ops.aten.select.int(arg4_1, 0, _local_scalar_dense); arg4_1 = _local_scalar_dense = None
t: "f32[3, 3]" = torch.ops.aten.t.default(arg6_1); arg6_1 = None
t_1: "f32[3, 3]" = torch.ops.aten.t.default(t); t = None
mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg1_1, t_1); t_1 = None
t_2: "f32[3, 3]" = torch.ops.aten.t.default(arg1_1)
mm_1: "f32[3, 3]" = torch.ops.aten.mm.default(t_2, select); t_2 = None
t_3: "f32[3, 3]" = torch.ops.aten.t.default(mm_1); mm_1 = None
sum_1: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(arg1_1, [0], True)
view: "f32[3]" = torch.ops.aten.view.default(sum_1, [3]); sum_1 = None
t_4: "f32[3, 3]" = torch.ops.aten.t.default(t_3); t_3 = None
mul_4: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg1_1, select)
mul_5: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg1_1, select); arg1_1 = select = None

add_7: "f32[3, 3]" = torch.ops.aten.add.Tensor(mm, mul_5); mm = mul_5 = None
add_8: "f32[3, 3]" = torch.ops.aten.add.Tensor(add_7, mul_4); add_7 = mul_4 = None

add_9: "i64[]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
add_10: "f32[3]" = torch.ops.aten.add.Tensor(view, arg2_1); view = arg2_1 = None
add_11: "f32[3, 3]" = torch.ops.aten.add.Tensor(t_4, arg3_1); t_4 = arg3_1 = None
return (add_9, add_8, add_10, add_11)
""", # noqa: B950
)

def test_input_output_alias(self):
def fn(f, *args):
return torch.cond(args[0].sum() > 0, f, f, args)
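As a side note on the graph assertions in the test above: a compact sketch of how the recorded graphs are inspected. Every name used here (AotEagerAndRecordGraphs, fw_graphs, bw_graphs, find_nodes, the higher_order ops) is taken from that test; this is the editor's illustration, not part of the diff:

    import torch

    backend = torch._dynamo.testing.AotEagerAndRecordGraphs()
    compiled = torch.compile(Mod(), backend=backend, fullgraph=True)  # Mod as defined in the test above
    compiled(torch.randn(3, 3, requires_grad=True)).sum().backward()

    # The forward graph stacks every iteration's carries via while_loop_stack_output ...
    fw_nodes = backend.fw_graphs[0].graph.find_nodes(
        op="call_function",
        target=torch.ops.higher_order.while_loop_stack_output,
    )
    # ... and the backward graph replays them through a plain while_loop.
    bw_nodes = backend.bw_graphs[0].graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.while_loop
    )
    assert len(fw_nodes) == 1 and len(bw_nodes) == 1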
@@ -2172,19 +2172,8 @@ class AOTInductorTestsTemplate:
dynamic_shapes=dynamic_shapes,
)

# mps doesn't support float64
@skipIfMPS
def test_while_loop_with_parameters(self):
inputs = (
torch.randn(
(
10,
20,
),
dtype=torch.float64,
device=self.device,
),
)
inputs = (torch.randn((10, 20), device=self.device),)
dim0_a = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"c": {},
@@ -804,12 +804,8 @@ class WhileLoopModels:
class InnerModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.layer1 = torch.nn.Linear(
20, 30, device=device, dtype=torch.float64
)
self.layer2 = torch.nn.Linear(
30, 20, device=device, dtype=torch.float64
)
self.layer1 = torch.nn.Linear(20, 30, device=device)
self.layer2 = torch.nn.Linear(30, 20, device=device)

def forward(self, c, x):
return c - 1, self.layer2(self.layer1(x - 2)) * 3.14
@@ -1029,7 +1025,7 @@ class WhileLoopModels:
e = torch.nonzero(b).size(0)

def cond_fn(c, a, b):
return c + d + e + a.shape[0] - b.shape[0] < 10
return d + e + a.shape[0] - b.shape[0] < 10

def body_fn(c, a, b):
return c + 1, a + e, b + d
@@ -1112,32 +1108,31 @@ class WhileLoopModels:

class WhileLoopTests(TestCase):
def _run_test(
self, model, inputs, device, dynamic=False, num_counters=1, autograd=False
self,
model,
inputs,
device,
dynamic=False,
num_counters=1,
):
import torch.utils._pytree as pytree

cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
import copy

if not autograd:
for p in model.parameters():
p.requires_grad_(False)

compiled_model = copy.deepcopy(model)
compiled_fn = torch.compile(backend=cnt, fullgraph=True)(compiled_model)
compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)

inputs = pytree.tree_map(lambda t: t.to(device=device), inputs)
input_sets = [inputs]

def mark_first_dim_dyn(inp):
torch._dynamo.mark_dynamic(inp, 0)

if dynamic:

def mark_first_dim_dyn(inp):
torch._dynamo.mark_dynamic(inp, 0)

pytree.tree_map(mark_first_dim_dyn, input_sets)

def tile_fn(inp):
# tile every first dim 5x
tiling = [5] + [1] * (inp.ndim - 1)
t = torch.tile(inp, tiling)
# mark every first dim as dynamic
torch._dynamo.mark_dynamic(inp, 0)
return t

larger_inputs = pytree.tree_map(tile_fn, inputs)
@@ -1154,78 +1149,24 @@ class WhileLoopTests(TestCase):
)
unflat_inputs = pytree.tree_unflatten(flat, inp_spec)
inputs_with_counters = counters + unflat_inputs

def process_inputs(inp):
inp = inp.clone()
if dynamic:
mark_first_dim_dyn(inp)

if autograd and inp.dtype.is_floating_point:
inp.requires_grad_(True)
return inp

cloned_inputs = pytree.tree_map(process_inputs, inputs_with_counters)
cloned_inputs2 = pytree.tree_map(process_inputs, inputs_with_counters)

result = model(*cloned_inputs)
result_compiled = compiled_fn(*cloned_inputs2)
cloned_inputs = pytree.tree_map(
lambda t: t.clone(), inputs_with_counters
)
result = model(*inputs_with_counters)
with torch.no_grad():
result_compiled = compiled_model(*inputs_with_counters)
# inputs must not be mutated
torch.testing.assert_close(cloned_inputs, inputs_with_counters)
torch.testing.assert_close(
result, result_compiled, atol=1e-4, rtol=1e-4
)

if autograd and any(
pytree.tree_map_only(
torch.Tensor, lambda t: t.requires_grad, cloned_inputs
)
):
result_loss = loss_fn(pytree.tree_flatten(result)[0])
compiled_loss = loss_fn(pytree.tree_flatten(result_compiled)[0])
self.assertTrue(
not torch.isnan(result_loss) and not torch.isinf(compiled_loss)
)
self.assertTrue(
not torch.isnan(compiled_loss)
and not torch.isinf(compiled_loss)
)

self.assertEqual(result_loss, compiled_loss)

result_loss.backward()
compiled_loss.backward()

model_parameters = dict(model.named_parameters())
compiled_parameters = dict(compiled_model.named_parameters())
for name, param in model_parameters.items():
self.assertEqual(param, compiled_parameters[name])
self.assertEqual(
param.grad,
compiled_parameters[name].grad,
atol=1e-4,
rtol=1e-4,
)

for inp1, inp2 in zip(
pytree.tree_flatten(cloned_inputs)[0],
pytree.tree_flatten(cloned_inputs2)[0],
):
if inp1.requires_grad:
self.assertEqual(
inp1.grad,
inp2.grad,
atol=1e-4,
rtol=1e-4,
)

self.assertEqual(cnt.frame_count, 1, "only one compilation expected")

@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_simple_control_flow(self, device, dynamic, autograd):
def test_while_loop_simple_control_flow(self, device, dynamic):
# while_loop control flow without nesting
self._run_test(
model=WhileLoopModels.Simple(),
@@ -1235,15 +1176,12 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)

@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_nested_control_flow(self, device, dynamic, autograd):
def test_while_loop_nested_control_flow(self, device, dynamic):
# while_loop control flow with nesting
self._run_test(
model=WhileLoopModels.Nested(),
@@ -1254,15 +1192,12 @@ class WhileLoopTests(TestCase):
device=device,
dynamic=dynamic,
num_counters=2,
autograd=autograd,
)

@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_outer_code(self, device, dynamic, autograd):
def test_while_loop_with_outer_code(self, device, dynamic):
# while_loop control flow with outer code
self._run_test(
model=WhileLoopModels.OuterCode(),
@@ -1272,22 +1207,18 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)

@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_parameters(self, device, dynamic, autograd):
def test_while_loop_with_parameters(self, device, dynamic):
# while_loop control flow with parameters
self._run_test(
model=WhileLoopModels.Parameters(device),
inputs=(torch.randn(10, 20, dtype=torch.float64),),
inputs=(torch.randn(10, 20),),
device=device,
dynamic=dynamic,
autograd=autograd,
)

@requires_gpu
@@ -1295,9 +1226,7 @@ class WhileLoopTests(TestCase):
# dynamic=True doesn't work now due to
# https://github.com/pytorch/pytorch/issues/123596
@parametrize("dynamic", [False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_outer_buffers(self, device, dynamic, autograd):
def test_while_loop_with_outer_buffers(self, device, dynamic):
# while_loop control flow with outer code
self._run_test(
model=WhileLoopModels.OuterBuffers(),
@@ -1307,15 +1236,13 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)

@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
# dynamic=True doesn't work because we haven't handled lifted symbols
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_pytree_inputs(self, device, dynamic, autograd):
def test_while_loop_with_pytree_inputs(self, device, dynamic):
self._run_test(
model=WhileLoopModels.PytreeCarry(),
inputs=(
@@ -1326,15 +1253,12 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)

@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_data_dependent_ops(self, device, dynamic, autograd):
def test_while_loop_with_data_dependent_ops(self, device, dynamic):
with torch._dynamo.config.patch(
{
"capture_dynamic_output_shape_ops": True,
@@ -1350,15 +1274,12 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)

@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_data_dependent_in_out(self, device, dynamic, autograd):
def test_while_loop_with_data_dependent_in_out(self, device, dynamic):
with torch._dynamo.config.patch(
{
"capture_dynamic_output_shape_ops": True,
@@ -1375,7 +1296,6 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)

@parametrize("dynamic", [True, False])
@@ -1436,8 +1356,7 @@ class WhileLoopTests(TestCase):
@torch._dynamo.config.patch(
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
)
@parametrize("autograd", [False, True])
def test_while_loop_with_unbacked_symint_closure(self, device, dynamic, autograd):
def test_while_loop_with_unbacked_symint_closure(self, device, dynamic):
self._run_test(
model=WhileLoopModels.UnbackedSymIntClosure(),
inputs=(
@@ -1446,7 +1365,6 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)

@requires_gpu
@@ -1481,11 +1399,10 @@ class WhileLoopTests(TestCase):
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch(
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
)
def test_while_loop_with_sym_expr_cond(self, device, dynamic, autograd):
def test_while_loop_with_sym_expr_cond(self, device, dynamic):
self._run_test(
model=WhileLoopModels.SymExprCond(),
inputs=(
@@ -1494,27 +1411,22 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)

@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_conv(self, device, dynamic, autograd):
def test_while_loop_with_conv(self, device, dynamic):
self._run_test(
model=WhileLoopModels.Conv(device),
inputs=(torch.randn(2, 4, 4, 4, dtype=torch.float64),),
device=device,
dynamic=dynamic,
autograd=autograd,
)

@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_stack_output_simple(self, device, dynamic):
self._run_test(
model=WhileLoopModels.WhileLoopStackOutputSimple(device),
@@ -2177,6 +2089,16 @@ class MapTests(TestCase):
self.assertEqual(result, result_compiled)

if autograd:

def loss_fn(result) -> torch.Tensor:
flat_results, _ = pytree.tree_flatten(result)
return sum(
[
torch.sqrt(torch.pow(res.sum() / res.max(), 2)).sum()
for res in flat_results
]
)

loss_fn(result).backward()
loss_fn(result_exp).backward()
loss_fn(result_compiled).backward()
@@ -12,9 +12,6 @@ from torch._higher_order_ops.utils import (
autograd_not_implemented,
check_input_alias_and_mutation_return_outputs,
check_meta_consistency,
fill_none_with_masks,
filter_with_masks,
materialize_as_graph,
reenter_make_fx,
validate_subgraph_args_types,
)
@@ -329,16 +326,9 @@ def while_loop_dense(
return carried_vals


@while_loop_op.py_autograd_impl
def while_loop_autograd(cond_fn, body_fn, operands, additional_inputs):
return WhileLoopAutogradOp.apply(
cond_fn,
body_fn,
len(operands),
len(additional_inputs),
*operands,
*additional_inputs,
)
while_loop_op.py_autograd_impl(
autograd_not_implemented(while_loop_op, deferred_error=True)
)


def _find_or_create_fake_mode() -> FakeTensorMode:
@@ -644,268 +634,6 @@ class WhileLoopStackOutputOp(HigherOrderOperator):
return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)


# Note [while_loop autograd]
# Consider the following while_loop that can be visualized as:
#                additional_inputs
#          ┌─────┬─────┼─────┬─────┐
#          |     |     |     |     |
#          ↓     ↓     ↓     ↓     ↓
# x ──→ y0 ─→ y1 ─→ y2 ─→ y3 ─→ y4
#
# The backward can be visualized as follows:
#
#                  g_additional_inputs
#            ┌──────┬──────┼──────┬──────┐
#            |      |      |      |      |
#            |      |      |      |      |
# gx <── gy0 <─ gy1 <─ gy2 <─ gy3 <─ gy4
#
# We can compute gx using chain rule:
#
#   gx = gy0 * bw(y0, x),
#
# where gy0 denotes the gradient of the loss with respect to y0, and bw(y0, x) denotes the gradient of y0 with
# respect to x. Note that bw can be computed from the forward body_fn easily using torch.autograd.grad.
# We could substitute the unknowns gy0, gy1, ..., with the chain rule until gy4:
#
#   gx = gy1 * bw(y1, y0) * bw(y0, x)
#      = gy2 * bw(y2, y1) * bw(y1, y0) * bw(y0, x)
#      = ...
#      = gy4 * bw(y4, y3) * bw(y3, y2) * bw(y2, y1) * bw(y1, y0) * bw(y0, x)
#
# Since gy4 is the gradient of the final output, which is given as the backward input, we've got a formula
# to compute gx. An abbreviation for the formula is: gy4 * bw43210x
#
# In a similar way, we can compute g_additional_inputs using the chain rule:
#
#   g_additional_inputs = gy0 * bw(y0, addi) + gy1 * bw(y1, addi) + gy2 * bw(y2, addi) + ... + gy4 * bw(y4, addi)
#
# Notice that gy0 = gy4 * bw43210, gy1 = gy4 * bw4321 etc., so we now also get a formula for g_additional_inputs.
#
# Implementation:
# The idea of the implementation is to construct a while_loop to calculate both gx and g_additional_inputs.
# Specifically, we can implement the backward of while_loop as follows:
#
# def cond_fn(idx, grad_carries, grad_additional_inputs, fw_additional_inputs, fw_inps):
#     return idx < fw_inps.size(0)
#
# def body_fn(idx, grad_carries, grad_additional_inputs, fw_additional_inputs, fw_inps):
#     reversed_idx = fw_inps.size(0) - 1 - idx
#     next_grad_carry, next_grad_additional_inputs = bw(fw_inps[reversed_idx], fw_additional_inputs, grad_carries)
#     return idx + 1, next_grad_carry, next_grad_additional_inputs + grad_additional_inputs
#
# idx = 0
# init_grad_carries = grads
# init_grad_additional_inputs = torch.zeros_like(g_additional_inputs)
# fw_inps = torch.cat([ctx.fw_carried_inputs, fw_outputs[:-1]])
# while_loop(cond_fn, body_fn, (idx, init_grad_carries, init_grad_additional_inputs,), (fw_additional_inputs, fw_inps))

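# (Editor's illustration, not part of the original file: a two-iteration instance
# of the formulas above. Take scalar carries with body y_{i+1} = a * y_i and a
# single additional input a, so y0 = a * x and y1 = a * y0:
#
#   bw(y0, x) = a,   bw(y1, y0) = a,   bw(y0, a) = x,   bw(y1, a) = y0
#   gy0 = gy1 * bw(y1, y0)                  = gy1 * a
#   gx  = gy0 * bw(y0, x)                   = gy1 * a * a
#   g_a = gy0 * bw(y0, a) + gy1 * bw(y1, a) = gy1 * a * x + gy1 * y0
#
# The carry gradient is a running product of per-step Jacobians, while the
# additional-input gradient is a running sum, which is exactly what the
# backward while_loop sketched above accumulates step by step.)
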
class WhileLoopAutogradOp(torch.autograd.Function):
@staticmethod
def forward(
ctx,
cond_fn,
body_fn,
num_carried_inputs,
num_additional_inputs,
*carries_and_inputs,
):
from torch._higher_order_ops.scan import split_into_chunks

carries, additional_inputs = split_into_chunks(
carries_and_inputs, [num_carried_inputs, num_additional_inputs]
)
with torch._C._AutoDispatchBelowAutograd():
fw_outputs = while_loop_stack_output_op(
cond_fn, body_fn, carries, additional_inputs
)

assert not hasattr(ctx, "fw_cond_fn")
assert not hasattr(ctx, "fw_body_fn")
assert not hasattr(ctx, "carries")
assert not hasattr(ctx, "additional_inputs")
assert not hasattr(ctx, "fw_outputs")
ctx.fw_cond_fn = cond_fn
ctx.fw_body_fn = body_fn
ctx.carries = carries
ctx.additional_inputs = additional_inputs
ctx.fw_outputs = fw_outputs
loop_count = None
for out in fw_outputs:
if isinstance(out, torch.Tensor):
if loop_count is not None:
assert out.size(0) == loop_count
else:
loop_count = out.size(0)
assert loop_count is not None

# Remove the loop_count from pending_fresh_unbacked_symbols
# because it's not part of the forward output and it's impossible
# to bind it to a proxy in the forward graph anyway.
if (
isinstance(loop_count, torch.SymInt)
and (shape_env := loop_count.node.shape_env)
and loop_count in shape_env.pending_fresh_unbacked_symbols
):
shape_env.pending_fresh_unbacked_symbols.remove(loop_count)

# Even when the body function is not executed, we clone and unsqueeze the input
# to avoid aliasing, therefore loop_count is always >= 1
torch._check(loop_count >= 1)
# We snapshot the dispatch keys in forward for materializing
# the bw_graph in backward.
ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
assert len(fw_outputs) > 0, "fw_outputs shouldn't be empty"
# Only the last element of each stacked fw_output needs to be returned
return tuple(ckp[-1] for ckp in fw_outputs)

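# (Editor's note, not in the original file: while_loop_stack_output stacks every
# iteration's carries along a new leading dimension, so in the test above a
# carry of shape [3, 3] that loops u2 times comes back as "f32[u2, 3, 3]".
# ckp[-1] is therefore the final carry returned to the user, while the
# torch.cat([carry.unsqueeze(0), carries[:-1]]) in backward() below rebuilds the
# inputs seen by each iteration, i.e. x, y0, ..., y3 in the note's notation.)
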
@staticmethod
def backward(ctx, *grads):
from torch._higher_order_ops.cond import create_bw_fn
from torch._higher_order_ops.scan import split_into_chunks

# set up the single-step bw fn
bw_body_fn = create_bw_fn(ctx.fw_body_fn, ctx.carries + ctx.additional_inputs)
# Note [Handle inputs that're not differentiable]
# When a forward input is non-differentiable, e.g. a symint or an integer tensor, its gradient
# will be None. However, we don't want to return None in the subgraph because this complicates the
# inductor codegen, where we need to do a non-uniform treatment for None and tensors.
# So we set up masks and filter the None gradients so that only tensors are returned from each step.
carries_tensor_masks = [
True if isinstance(t, torch.Tensor) and t.dtype.is_floating_point else False
for t in ctx.carries
]
additional_inputs_tensor_masks = [
True if isinstance(t, torch.Tensor) and t.dtype.is_floating_point else False
for t in ctx.additional_inputs
]

init_idx = torch.zeros((), dtype=torch.int64)
init_grad_carries = filter_with_masks(grads, carries_tensor_masks) # type: ignore[arg-type]
init_grad_additional_inputs = tuple(
torch.zeros_like(t)
for need_keep, t in zip(
additional_inputs_tensor_masks, ctx.additional_inputs
)
if need_keep
)
# We need the forward inputs to each iteration to compute the backward,
# which is the concatenation of the first iteration's input, i.e. ctx.carries, and all
# iterations' outputs except the last iteration's.
fw_carries = [
torch.cat([carry.unsqueeze(0), carries[:-1]])
for carry, carries in zip(ctx.carries, ctx.fw_outputs)
]
for fw_carry, carry in zip(fw_carries, ctx.carries):
fw_carry.requires_grad_(carry.requires_grad)

_, spec = pytree.tree_flatten(
(
init_idx,
init_grad_carries,
init_grad_additional_inputs,
ctx.fw_outputs,
ctx.additional_inputs,
)
)

def cond_fn(*flat_args):
(
idx,
grad_carries,
grad_additional_inputs,
fw_carries,
additional_inputs,
) = pytree.tree_unflatten(flat_args, spec)
assert isinstance(fw_carries[0], torch.Tensor), fw_carries[0]
# excluding the last iteration's output
return idx < fw_carries[0].size(0)

def body_fn(*flat_args):
(
idx,
grad_carries,
grad_additional_inputs,
fw_carries,
additional_inputs,
) = pytree.tree_unflatten(flat_args, spec)
reversed_idx = fw_carries[0].size(0) - idx - 1
selected_fw_carries = [
ckp.select(0, reversed_idx.item()) for ckp in fw_carries
]
cur_grad_carries, cur_grad_additional_inputs = split_into_chunks(
bw_body_fn(*selected_fw_carries, *additional_inputs, *grad_carries),
[len(ctx.carries), len(ctx.additional_inputs)],
)
assert all(isinstance(t, torch.Tensor) for t in cur_grad_carries)
cur_grad_carries_tensors = filter_with_masks(
cur_grad_carries, carries_tensor_masks
)
cur_grad_additional_inputs_tensors = filter_with_masks(
cur_grad_additional_inputs, additional_inputs_tensor_masks
)
return (
idx + 1,
*cur_grad_carries_tensors,
*(
cur_grad + grad
for cur_grad, grad in zip(
cur_grad_additional_inputs_tensors, grad_additional_inputs
)
),
)

args_single_step_bw = (
init_idx,
*init_grad_carries,
*init_grad_additional_inputs,
*fw_carries,
*ctx.additional_inputs,
)

cond_gm = materialize_as_graph(
cond_fn,
args_single_step_bw,
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
force_enable_grad=True,
)

body_gm = materialize_as_graph(
body_fn,
args_single_step_bw,
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
force_enable_grad=True,
)

_, final_grad_carries, final_grad_additional_inputs = split_into_chunks(
while_loop_op(
cond_gm,
body_gm,
(
init_idx,
*init_grad_carries,
*init_grad_additional_inputs,
),
(*fw_carries, *ctx.additional_inputs),
),
[1, len(init_grad_carries), len(init_grad_additional_inputs)],
)
return (
None,
None,
None,
None,
*fill_none_with_masks(final_grad_carries, carries_tensor_masks),
*fill_none_with_masks(
final_grad_additional_inputs, additional_inputs_tensor_masks
),
)


while_loop_stack_output_op = WhileLoopStackOutputOp()

while_loop_stack_output_op.py_impl(DispatchKey.CompositeExplicitAutograd)(
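To make Note [while_loop autograd] concrete outside of the traced implementation, here is a small eager-mode sketch by the editor of the same reverse accumulation, written directly with torch.autograd.grad. The body, shapes, and fixed step count are illustrative assumptions, not code from the PR; the real op derives the step count from the stacked forward outputs instead of a fixed range:

    import torch

    def body_fn(y, a):
        # one forward step; `a` plays the role of an additional input (e.g. a weight)
        return y * y + a

    def manual_while_loop_grad(x, a, num_steps, grad_out):
        # Forward: record the input to every iteration (x, y0, ..., y_{n-2}),
        # mirroring fw_inps = cat([carries, fw_outputs[:-1]]) in the note above.
        fw_inps, y = [], x
        for _ in range(num_steps):  # fixed count stands in for the data-dependent cond_fn
            fw_inps.append(y)
            y = body_fn(y, a)

        # Backward: walk the recorded inputs in reverse, chaining the carry gradient
        # and summing the additional-input gradient, like body_fn in the note.
        grad_carry, grad_a = grad_out, torch.zeros_like(a)
        for inp in reversed(fw_inps):
            inp = inp.detach().requires_grad_(True)
            out = body_fn(inp, a)
            g_inp, g_a = torch.autograd.grad(out, (inp, a), grad_carry)
            grad_carry, grad_a = g_inp, grad_a + g_a
        return grad_carry, grad_a

    x = torch.randn(3, 3)
    a = torch.randn(3, 3, requires_grad=True)
    gx, ga = manual_while_loop_grad(x, a, num_steps=4, grad_out=torch.ones(3, 3))

Here grad_carry is threaded through each step like grad_carries in the pseudocode, and grad_a accumulates like grad_additional_inputs.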
@@ -2345,6 +2345,7 @@ class _MakefxTracer:

insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx")
t.recompile()

# TODO: kind of a bad way to do it, should maybe figure out a better way
if self.tracing_mode == "symbolic":
assert self.fake_tensor_mode is not None