mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[annotate] Annotate bw nodes before eliminate dead code (#165782)
Fixes https://github.com/pytorch/torchtitan/pull/1907 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165782 Approved by: https://github.com/SherlockNoMad
This commit is contained in:
committed by
PyTorch MergeBot
parent
de3da77cf7
commit
cf3a787bbc
@ -468,12 +468,16 @@ def aot_dispatch_autograd_graph(
|
||||
# a fake tensor. Unlikely.
|
||||
# See Note: [Fake Modules and AOTAutograd]
|
||||
torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
|
||||
|
||||
# Have to copy before eliminate_dead_code otherwise the
|
||||
# fw node match might be erased
|
||||
copy_fwd_metadata_to_bw_nodes(fx_g)
|
||||
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
if not aot_config.disable_functionalization:
|
||||
# There should be *NO* mutating ops in the graph at this point.
|
||||
assert_functional_graph(fx_g.graph)
|
||||
|
||||
copy_fwd_metadata_to_bw_nodes(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
# TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect
|
||||
|
Reference in New Issue
Block a user