pytorch/torch/csrc/jit/passes/utils/subgraph_utils.h
Karl Ostmo 8f0603b128 C++ changes toward libtorch and libcaffe2 unification (#19554)
Summary:
* adds TORCH_API and AT_CUDA_API in places
* refactor code generation Python logic to separate
  caffe2/torch outputs
* fix hip and asan
* remove profiler_cuda from hip
* fix gcc warnings for enums
* Fix PythonOp::Kind
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19554

Differential Revision: D15082727

Pulled By: kostmo

fbshipit-source-id: 83a8a99717f025ab44b29608848928d76b3147a4
2019-04-26 01:38:10 -07:00


#pragma once

#include <torch/csrc/jit/ir.h>
#include <torch/csrc/WindowsTorchApiMacro.h>

namespace torch {
namespace jit {

// Utilities for dealing with nodes that contain subgraphs.
//
// They handle the complexity of editing inputs/outputs as you merge nodes in
// and out of subgraphs.
namespace SubgraphUtils {

// Create a new subgraph node that contains only `n`. The new subgraph will have
// `subgraphKind` as its type.
//
// `n` is destroyed.
//
// Returns the new subgraph node.
TORCH_API Node* createSingletonSubgraph(Node* n, Symbol subgraphKind);

// Merge a node into a subgraph node. If `toMerge` is also a subgraph, the
// subgraphs are merged.
//
// `toMerge` is destroyed.
TORCH_API void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode);

// Move nodes from a subgraph node to the outer graph.
//
// `subgraphNode` is destroyed.
TORCH_API void unmergeSubgraph(Node* subgraphNode);

// Convenience function: returns the graph stored in the `attr::Subgraph`
// attribute of `n`.
TORCH_API std::shared_ptr<Graph> getSubgraph(Node* n);

} // namespace SubgraphUtils
} // namespace jit
} // namespace torch
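The declarations above only state the contract. The sketch below is a hypothetical toy model in self-contained C++ (it does not use libtorch) that mimics that contract on a simplified IR. `ToyNode`, `ToyGraph`, `exampleGraph`, `subgraphInputs`, and `subgraphOutputs` are invented for illustration; nodes are addressed by index instead of `Node*`, and merging is restricted to adjacent nodes so that moving a node cannot break topological order. The real utilities rewire `Value*` uses as they go, whereas this toy recomputes a subgraph's inputs (values read inside but defined outside) and outputs (values defined inside but read outside) on demand.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only; this is NOT torch::jit::Node.
// An ordinary node defines one named value and reads named inputs.
// A subgraph node (e.g. kind "prim::FusionGroup") owns an ordered body.
struct ToyNode {
  std::string kind;
  std::string output;               // empty for subgraph and return nodes
  std::vector<std::string> inputs;  // empty for subgraph nodes
  std::vector<ToyNode> body;        // non-empty only for subgraph nodes
  bool isSubgraph() const { return !body.empty(); }
};
using ToyGraph = std::vector<ToyNode>;  // top-level nodes, topologically ordered

// a = mul(x, y); b = add(a, x); c = relu(b); return (c, a)
ToyGraph exampleGraph() {
  return {
      {"aten::mul", "a", {"x", "y"}, {}},
      {"aten::add", "b", {"a", "x"}, {}},
      {"aten::relu", "c", {"b"}, {}},
      {"prim::Return", "", {"c", "a"}, {}},
  };
}

// Toy analogue of createSingletonSubgraph: wrap g[i] in a new subgraph node
// of the given kind. The original top-level node no longer exists.
std::size_t createSingletonSubgraph(ToyGraph& g, std::size_t i,
                                    const std::string& kind) {
  ToyNode group{kind, "", {}, {g[i]}};
  g[i] = std::move(group);
  return i;
}

// Toy analogue of mergeNodeIntoSubgraph. If toMerge is itself a subgraph,
// its body is spliced in, so the two subgraphs become one. The moved nodes go
// on the side that preserves the original order. toMerge is erased, which may
// shift the subgraph node, so its new index is returned.
std::size_t mergeNodeIntoSubgraph(ToyGraph& g, std::size_t toMerge,
                                  std::size_t sub) {
  assert(g[sub].isSubgraph() && toMerge != sub);
  assert(toMerge + 1 == sub || sub + 1 == toMerge);  // toy restriction
  std::vector<ToyNode> moved = g[toMerge].isSubgraph()
                                   ? g[toMerge].body
                                   : std::vector<ToyNode>{g[toMerge]};
  std::vector<ToyNode>& body = g[sub].body;
  if (toMerge < sub) {
    body.insert(body.begin(), moved.begin(), moved.end());
  } else {
    body.insert(body.end(), moved.begin(), moved.end());
  }
  g.erase(g.begin() + static_cast<std::ptrdiff_t>(toMerge));
  return toMerge < sub ? sub - 1 : sub;
}

// Toy analogue of unmergeSubgraph: splice the body back into the outer graph
// at the subgraph node's position. The subgraph node is erased.
void unmergeSubgraph(ToyGraph& g, std::size_t sub) {
  std::vector<ToyNode> body = g[sub].body;
  g.erase(g.begin() + static_cast<std::ptrdiff_t>(sub));
  g.insert(g.begin() + static_cast<std::ptrdiff_t>(sub), body.begin(),
           body.end());
}

// Subgraph inputs: values read inside the body but not defined there.
std::vector<std::string> subgraphInputs(const ToyNode& sub) {
  std::set<std::string> defined;
  std::vector<std::string> result;
  for (const ToyNode& n : sub.body) {
    for (const std::string& v : n.inputs) {
      if (!defined.count(v) &&
          std::find(result.begin(), result.end(), v) == result.end()) {
        result.push_back(v);
      }
    }
    defined.insert(n.output);
  }
  return result;
}

// Subgraph outputs: values defined in the body and read by another node.
std::vector<std::string> subgraphOutputs(const ToyGraph& g, std::size_t sub) {
  std::set<std::string> readOutside;
  for (std::size_t i = 0; i < g.size(); ++i) {
    if (i == sub) continue;
    for (const std::string& v : g[i].inputs) readOutside.insert(v);
    for (const ToyNode& n : g[i].body)
      for (const std::string& v : n.inputs) readOutside.insert(v);
  }
  std::vector<std::string> result;
  for (const ToyNode& n : g[sub].body)
    if (readOutside.count(n.output)) result.push_back(n.output);
  return result;
}
```

For example, starting from `exampleGraph()`, wrapping `relu` with `createSingletonSubgraph` and then merging `add` and `mul` into it leaves a single subgraph node whose inputs are `x, y` and whose outputs are `a, c` (`a` escapes because the return node reads it); `unmergeSubgraph` then restores the original four nodes.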