Speed up _extract_graph_with_inputs_outputs (#125937)

_extract_graph_with_inputs_outputs() tests every node for membership in the input nodes, but that collection is often a list, so each test is O(n). Ensure it's a set before looping over all the nodes.
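
To illustrate the cost difference described above, here is a minimal standalone sketch. It is not the PyTorch code: `count_members_slow` and `count_members_fast` are hypothetical names, and plain integers stand in for graph nodes.

```python
def count_members_slow(nodes, inputs):
    # If `inputs` is a list, every `in` test scans it, so the whole loop
    # costs O(len(nodes) * len(inputs)).
    return sum(1 for node in nodes if node in inputs)


def count_members_fast(nodes, inputs):
    # Build the set once; each later `in` test is O(1) on average.
    if not isinstance(inputs, (set, frozenset)):
        inputs = set(inputs)
    return sum(1 for node in nodes if node in inputs)


nodes = list(range(2000))
inputs = list(range(0, 2000, 2))  # a list, as the callers often pass

# Both versions agree; only the amount of work differs.
assert count_members_slow(nodes, inputs) == count_members_fast(nodes, inputs) == 1000
```

The conversion only pays off when the collection is tested many times, which is the case here because the loop visits every node in the graph.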

This change speeds up the internal repro (D57090987) by about 18% in user CPU time, as reported by `time`:
before:
```
708.88user 15.86system 12:16.19elapsed 98%CPU (0avgtext+0avgdata 12898628maxresident)k
0inputs+91968outputs (3major+3532970minor)pagefaults 0swaps
```
after:
```
583.39user 15.98system 10:10.11elapsed 98%CPU (0avgtext+0avgdata 12895108maxresident)k
0inputs+87488outputs (4major+3374582minor)pagefaults 0swaps
```
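
The percentage can be checked from the two measurements above. The sketch below copies the user and elapsed figures from that output; `to_seconds` and `reduction` are hypothetical helper names introduced only for this check.

```python
def to_seconds(elapsed):
    # Convert the "M:SS.ss" elapsed format shown above into seconds.
    minutes, seconds = elapsed.split(":")
    return int(minutes) * 60 + float(seconds)


def reduction(before, after):
    # Fractional reduction relative to the "before" measurement.
    return (before - after) / before


user_saving = reduction(708.88, 583.39)
wall_saving = reduction(to_seconds("12:16.19"), to_seconds("10:10.11"))
print(f"user: {user_saving:.1%}, wall: {wall_saving:.1%}")
# prints "user: 17.7%, wall: 17.1%"
```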

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125937
Approved by: https://github.com/oulgen, https://github.com/anijain2305
Aaron Orenstein
2024-05-10 10:32:24 -07:00
committed by PyTorch MergeBot
parent 4457cd9a30
commit a5c93a6899

@@ -96,7 +96,10 @@ def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs):
         env[node] = new_node
 
     for node in joint_graph.nodes:
-        if node in inputs:
+        if node in env:
+            # Node must be one of our inputs. (Any member of env which wasn't an
+            # input to start must have been created by this loop and won't be in
+            # joint_graph.nodes).
             continue
         elif node.op == "placeholder":
             env[node] = InvalidNode
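
The hunk above tests membership against `env`, the dict that already holds one entry per input, rather than against `inputs` itself. A minimal standalone sketch of that pattern follows. Plain strings stand in for graph nodes, `INVALID` stands in for `InvalidNode`, and `build_env`, `classify`, and the final `copied:` branch are assumptions made for illustration, not code from PyTorch.

```python
INVALID = object()  # stand-in for the InvalidNode sentinel in the diff


def build_env(inputs):
    # One entry per input, mirroring the `env[node] = new_node` context line.
    return {node: f"placeholder:{node}" for node in inputs}


def classify(all_nodes, inputs, placeholders):
    env = build_env(inputs)
    for node in all_nodes:
        if node in env:
            # Dict lookup: O(1) on average, even when `inputs` is a list.
            continue
        elif node in placeholders:
            env[node] = INVALID
        else:
            # Hypothetical catch-all branch; the real function does more here.
            env[node] = f"copied:{node}"
    return env


env = classify(["a", "b", "c", "d"], inputs=["a", "c"], placeholders={"b"})
assert env["a"] == "placeholder:a"  # inputs keep their placeholder entry
assert env["b"] is INVALID          # non-input placeholder is marked invalid
```

Because each node is visited once, any key already present in `env` when the loop reaches it must have come from `inputs`, which is why the dict lookup can replace the list scan without a separate set.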