diff --git a/torch/_functorch/_aot_autograd/graph_capture.py b/torch/_functorch/_aot_autograd/graph_capture.py
index 91af2933cc28..132cf335b387 100644
--- a/torch/_functorch/_aot_autograd/graph_capture.py
+++ b/torch/_functorch/_aot_autograd/graph_capture.py
@@ -468,12 +468,16 @@ def aot_dispatch_autograd_graph(
     # a fake tensor.  Unlikely.
     # See Note: [Fake Modules and AOTAutograd]
     torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
+
+    # Have to copy before eliminate_dead_code otherwise the
+    # fw node match might be erased
+    copy_fwd_metadata_to_bw_nodes(fx_g)
+
     fx_g.graph.eliminate_dead_code()
     if not aot_config.disable_functionalization:
         # There should be *NO* mutating ops in the graph at this point.
         assert_functional_graph(fx_g.graph)
-        copy_fwd_metadata_to_bw_nodes(fx_g)
 
     fx_g.recompile()
 
     # TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect
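
For context on why the reordering matters: fx.Graph.eliminate_dead_code() deletes any pure node whose output has no users, so a forward node that copy_fwd_metadata_to_bw_nodes would later match against (to propagate forward metadata onto backward nodes) can be erased before the copy runs. Below is a minimal standalone sketch of just that DCE behavior; the module M is a hypothetical example and this is not AOTAutograd's actual matching logic.

    # Sketch only: shows that torch.fx DCE removes unused pure nodes,
    # which is why a pass that matches against fw nodes must run first.
    import torch
    import torch.fx


    class M(torch.nn.Module):  # hypothetical example module
        def forward(self, x):
            unused = x.sin()  # pure node with no users -> dead code
            return x.cos()


    gm = torch.fx.symbolic_trace(M())
    before = {n.name for n in gm.graph.nodes}

    gm.graph.eliminate_dead_code()  # same call as in the diff above
    after = {n.name for n in gm.graph.nodes}

    print(before - after)  # {'sin'} -- the unused node is gone

With the patch, the metadata copy sees every forward node while it still exists; dead code elimination still runs immediately afterwards, so unused nodes are removed either way.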