Remove uses of deleted operations (#139447)

resolves: https://github.com/pytorch/pytorch/issues/138721

Summary:

Delete the uses of deleted nodes. The double for-loop is icky here, but N should
be pretty small and removing it requires refactoring the datastructures
involved, which is a bigger endeavor.

Test Plan:

Normal test coverage should be sufficient. There were a couple of spots in the
scheduler code that didn't check users being deleted, so I'll run a perf test to see
what impact that has, and to make sure N^2 doesn't affect compile times.

Perf:
https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2029%20Oct%202024%2017%3A41%3A36%20GMT&stopTime=Tue%2C%2005%20Nov%202024%2018%3A41%3A36%20GMT&granularity=hour&suite=torchbench&mode=inference&dtype=bfloat16&deviceName=cuda%20(a100)&lBranch=exclamaforte/prune-deleted-users&lCommit=5cb1aa6f7d8a52acdae0c7cf36b8c2d536d7f0d1&rBranch=main&rCommit=f4ee5a243dbb31e6310e5632b1c87898b299df2c
off of nov4 nightly

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139447
Approved by: https://github.com/eellison
This commit is contained in:
Gabriel Ferns
2024-11-08 22:21:50 +00:00
committed by PyTorch MergeBot
parent 347f96061f
commit 95198f8299

View File

@ -2174,7 +2174,12 @@ class Scheduler:
# dead code
log.debug("removed dead operation: %s", node.get_name())
V.graph.removed_operations.add(node.get_name())
for read in node.read_writes.reads:
if read.name in self.name_to_buf:
users = self.name_to_buf[read.name].users
self.name_to_buf[read.name].users = [
u for u in users if u.node.get_name() != node.get_name()
]
self.nodes = list(reversed(updated_nodes))
# Prune any WeakDeps no longer needed