Add output_node util function to fx.Graph (#139770)

Summary: A util function for access output node for FX graph

Test Plan: OSS CI

Differential Revision: D65486457

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139770
Approved by: https://github.com/ezyang, https://github.com/Chillee
This commit is contained in:
Sherlock Huang
2024-11-07 18:54:59 +00:00
committed by PyTorch MergeBot
parent ee54dfb64d
commit 071d48c56e
2 changed files with 7 additions and 1 deletions

View File

@ -1,6 +1,6 @@
torch.fx._symbolic_trace.ProxyableClassMeta []
torch.fx._symbolic_trace.Tracer ['call_module', 'create_arg', 'create_args_for_root', 'get_fresh_qualname', 'getattr', 'is_leaf_module', 'path_of_module', 'trace']
torch.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'find_nodes', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'on_generate_code', 'output', 'owning_module', 'placeholder', 'print_tabular', 'process_inputs', 'process_outputs', 'python_code', 'set_codegen']
torch.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'find_nodes', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'on_generate_code', 'output', 'output_node', 'owning_module', 'placeholder', 'print_tabular', 'process_inputs', 'process_outputs', 'python_code', 'set_codegen']
torch.fx.graph.PythonCode []
torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'print_readable', 'recompile', 'to_folder']
torch.fx.immutable_collections.immutable_dict ['clear', 'pop', 'popitem', 'setdefault', 'update']

View File

@ -1026,6 +1026,12 @@ class Graph:
"""
return _node_list(self)
@compatibility(is_backward_compatible=False)
def output_node(self) -> Node:
output_node = next(iter(reversed(self.nodes)))
assert output_node.op == "output"
return output_node
@compatibility(is_backward_compatible=False)
def find_nodes(
self, *, op: str, target: Optional["Target"] = None, sort: bool = True