[annotate] Annotate bw nodes before eliminate dead code (#165782)

Fixes https://github.com/pytorch/torchtitan/pull/1907

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165782
Approved by: https://github.com/SherlockNoMad
This commit is contained in:
Shangdi Yu
2025-10-18 01:54:27 +00:00
committed by PyTorch MergeBot
parent de3da77cf7
commit cf3a787bbc

View File

@ -468,12 +468,16 @@ def aot_dispatch_autograd_graph(
# a fake tensor. Unlikely.
# See Note: [Fake Modules and AOTAutograd]
torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
# Have to copy before eliminate_dead_code otherwise the
# fw node match might be erased
copy_fwd_metadata_to_bw_nodes(fx_g)
fx_g.graph.eliminate_dead_code()
if not aot_config.disable_functionalization:
# There should be *NO* mutating ops in the graph at this point.
assert_functional_graph(fx_g.graph)
copy_fwd_metadata_to_bw_nodes(fx_g)
fx_g.recompile()
# TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect