mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Speed up _extract_graph_with_inputs_outputs (#125937)
_extract_graph_with_inputs_outputs() does membership testing on the input nodes but often that collection is a list so the test is O(n). Ensure it's a set before looping over all the nodes. This change speeds up the internal repro (D57090987) by about 18%: before: ``` 708.88user 15.86system 12:16.19elapsed 98%CPU (0avgtext+0avgdata 12898628maxresident)k 0inputs+91968outputs (3major+3532970minor)pagefaults 0swaps ``` after: ``` 583.39user 15.98system 10:10.11elapsed 98%CPU (0avgtext+0avgdata 12895108maxresident)k 0inputs+87488outputs (4major+3374582minor)pagefaults 0swaps ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/125937 Approved by: https://github.com/oulgen, https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
4457cd9a30
commit
a5c93a6899
@ -96,7 +96,10 @@ def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs):
|
||||
env[node] = new_node
|
||||
|
||||
for node in joint_graph.nodes:
|
||||
if node in inputs:
|
||||
if node in env:
|
||||
# Node must be one of our inputs. (Any member of env which wasn't an
|
||||
# input to start must have been created by this loop and won't be in
|
||||
# joint_graph.nodes).
|
||||
continue
|
||||
elif node.op == "placeholder":
|
||||
env[node] = InvalidNode
|
||||
|
Reference in New Issue
Block a user