From a5c93a6899c657832944cd2eeb5069449e28dbea Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Fri, 10 May 2024 10:32:24 -0700 Subject: [PATCH] Speed up _extract_graph_with_inputs_outputs (#125937) _extract_graph_with_inputs_outputs() does membership testing on the input nodes but often that collection is a list so the test is O(n). Ensure it's a set before looping over all the nodes. This change speeds up the internal repro (D57090987) by about 18%: before: ``` 708.88user 15.86system 12:16.19elapsed 98%CPU (0avgtext+0avgdata 12898628maxresident)k 0inputs+91968outputs (3major+3532970minor)pagefaults 0swaps ``` after: ``` 583.39user 15.98system 10:10.11elapsed 98%CPU (0avgtext+0avgdata 12895108maxresident)k 0inputs+87488outputs (4major+3374582minor)pagefaults 0swaps ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/125937 Approved by: https://github.com/oulgen, https://github.com/anijain2305 --- torch/_functorch/partitioners.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 4ac5596cde50..ba549e5bd6e2 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -96,7 +96,10 @@ def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs): env[node] = new_node for node in joint_graph.nodes: - if node in inputs: + if node in env: + # Node must be one of our inputs. (Any member of env which wasn't an + # input to start must have been created by this loop and won't be in + # joint_graph.nodes). continue elif node.op == "placeholder": env[node] = InvalidNode