Files
pytorch/test/cpp/jit/test_subgraph_utils.cpp
Michael Suo dfdb86a595 big cpp test reorg (#24801)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24801

This is to fix the ODR-violations in fbcode static builds, which have been broken for several months.

This PR is unfortunately quite large, but the changes are only mechanical:
1. Tests defined in header files -> tests defined in cpp files
2. Remove the `torch::jit::testing` namespace -> `torch::jit`.
3. Single `test.h` file that aggregates all tests.
4. Separate out files for gtest and python versions of the tests instead of using a build flag
5. Add a README that explains how to add a new test and why the C++ tests are structured the way they are.

Test Plan: Imported from OSS

Differential Revision: D16878605

Pulled By: suo

fbshipit-source-id: 27b5c077dadd990a5f74e25d01731f9c1f491603
2019-08-18 16:49:56 -07:00

42 lines
1.1 KiB
C++

#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
namespace torch {
namespace jit {
void testSubgraphUtils() {
auto graph = build_lstm();
EliminateCommonSubexpression(graph);
std::vector<Node*> originalNodes(
graph->nodes().begin(), graph->nodes().end());
// Merge everything into a single subgraph
bool first = true;
Node* subgraph;
for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend();) {
if (first) {
subgraph = SubgraphUtils::createSingletonSubgraph(
*it, prim::DifferentiableGraph);
it = ++subgraph->reverseIterator();
first = false;
}
SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph);
it = ++subgraph->reverseIterator();
}
// Unmerge and compare with original node listing
SubgraphUtils::unmergeSubgraph(subgraph);
EliminateCommonSubexpression(graph);
std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
ASSERT_EQ(originalNodes.size(), newNodes.size());
}
} // namespace jit
} // namespace torch