[functorch] fix AOTAutograd tutorial (#87415)
It was raising asserts previously.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87415
Approved by: https://github.com/Chillee
commit 13ab819356 (parent b1cf377cce), committed by PyTorch MergeBot
@@ -126,7 +126,10 @@
    "aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)\n",
    "\n",
    "# Run the aot_print_fn once to trigger the compilation and print the graphs\n",
    "res = aot_print_fn(a, b, c, d).sum().backward()\n",
    "cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n",
    "cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n",
    "res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)\n",
    "res.sum().backward()\n",
    "assert torch.allclose(ref, res)"
   ]
  },
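For readers without the notebook at hand, here is a minimal, self-contained sketch of the fixed cell. The definitions of `fn`, the inputs, `ref`, and `compiler_fn` are not part of this diff; they are assumptions reconstructed from the tutorial's earlier cells.

```python
# Sketch of the fixed tutorial cell. fn, the inputs, ref, and compiler_fn are
# assumed reconstructions of the notebook's earlier cells, not part of this diff.
import torch
from functorch.compile import aot_function

def fn(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()

# Random leaf inputs that require gradients (assumed shapes)
a, b, c, d = [torch.randn(128, requires_grad=True) for _ in range(4)]

# Reference result computed in plain eager mode
ref = fn(a, b, c, d)

def compiler_fn(fx_module, example_inputs):
    # Print the captured forward/backward graph and return it unchanged
    print(fx_module.code)
    return fx_module

aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)

# Run the aot_print_fn once to trigger the compilation and print the graphs
res = aot_print_fn(a, b, c, d).sum().backward()

# Re-run on cloned leaves so the correctness check starts from clean .grad
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)
res.sum().backward()
assert torch.allclose(ref, res)
```

Note that the first call exists only to trigger compilation and the graph printing (`.backward()` returns `None`, so `res` is only meaningful after the second call); cloning the inputs gives fresh leaves with empty `.grad`, so the comparison against `ref` is not polluted by that warm-up run.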
@@ -300,6 +303,9 @@
   "source": [
    "from functorch.compile import min_cut_rematerialization_partition\n",
    "\n",
    "# Zero out the gradients so we can do a comparison later\n",
    "a.grad, b.grad, c.grad, d.grad = (None,) * 4\n",
    "\n",
    "# Let's set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.\n",
    "# This will show us how the recomputation has modified the graph.\n",
    "aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)\n",
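This hunk is truncated mid-cell by the page. As a hedged reconstruction, the sketch below shows how the partitioner cell plausibly runs end to end; `fn`, the inputs, and `compiler_fn` are the same assumed definitions as above (repeated so the block runs on its own), and the final comparison lines are an assumed continuation of the cut-off cell.

```python
# Sketch of the min-cut rematerialization cell this hunk begins. fn, the
# inputs, and compiler_fn are assumptions; the closing lines after aot_fn
# are an assumed continuation of the truncated hunk.
import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition

def fn(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()

a, b, c, d = [torch.randn(128, requires_grad=True) for _ in range(4)]
ref = fn(a, b, c, d)

def compiler_fn(fx_module, example_inputs):
    print(fx_module.code)
    return fx_module

# Zero out the gradients so we can do a comparison later
a.grad, b.grad, c.grad, d.grad = (None,) * 4

# Set up the partitioner, with the fwd and bwd compilers set to the printer
# function used earlier, to show how recomputation modifies the graphs.
aot_fn = aot_function(
    fn,
    fw_compiler=compiler_fn,
    bw_compiler=compiler_fn,
    partition_fn=min_cut_rematerialization_partition,
)

res = aot_fn(a, b, c, d)
res.sum().backward()
assert torch.allclose(ref, res)
```

The design point of `min_cut_rematerialization_partition` is that instead of saving every intermediate for backward, it solves a min-cut to decide which values to stash and which to recompute in the backward graph, trading a little recomputation for lower memory traffic; the printer compilers make that change visible in the printed graphs.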