[functorch] fix AOTAutograd tutorial (#87415)

It was raising asserts previously
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87415
Approved by: https://github.com/Chillee
This commit is contained in:
Richard Zou
2022-10-20 15:40:03 -07:00
committed by PyTorch MergeBot
parent b1cf377cce
commit 13ab819356

View File

@ -126,7 +126,10 @@
"aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)\n",
"\n",
"# Run the aot_print_fn once to trigger the compilation and print the graphs\n",
"res = aot_print_fn(a, b, c, d).sum().backward()\n",
"cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n",
"cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n",
"res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)\n",
"res.sum().backward()\n",
"assert torch.allclose(ref, res)"
]
},
@ -300,6 +303,9 @@
"source": [
"from functorch.compile import min_cut_rematerialization_partition\n",
"\n",
"# Zero out the gradients so we can do a comparison later\n",
"a.grad, b.grad, c.grad, d.grad = (None,) * 4\n",
"\n",
"# Lets set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.\n",
"# This will show us how the recomputation has modified the graph.\n",
"aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)\n",