Files
pytorch/test/cpp/jit/test_ir.cpp
Michael Suo dfdb86a595 big cpp test reorg (#24801)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24801

This fixes the ODR violations in fbcode static builds, which have been broken for several months.

This PR is unfortunately quite large, but the changes are only mechanical:
1. Tests defined in header files -> tests defined in cpp files
2. Remove the `torch::jit::testing` namespace -> `torch::jit`.
3. Single `test.h` file that aggregates all tests.
4. Separate out files for the gtest and Python versions of the tests instead of using a build flag.
5. Add a README explaining how to add a new test, and a bit about why the C++ tests are structured the way they are.

Test Plan: Imported from OSS

Differential Revision: D16878605

Pulled By: suo

fbshipit-source-id: 27b5c077dadd990a5f74e25d01731f9c1f491603
2019-08-18 16:49:56 -07:00

86 lines
2.1 KiB
C++

#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
namespace torch {
namespace jit {
void testAttributes() {
Graph g;
auto one = attr::alpha;
auto two = attr::device;
auto three = attr::end;
auto four = attr::perm;
Node* n = g.create(Symbol::fromQualString("foo::bar"));
Node& attr = *n;
attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what");
ASSERT_EQ(attr.f(one), 3.4);
ASSERT_EQ(attr.s(three), "what");
ASSERT_EQ(attr.i(two), 5);
attr.s_(one, "no");
ASSERT_EQ(attr.s(one), "no");
ASSERT_TRUE(attr.hasAttribute(three));
ASSERT_TRUE(!attr.hasAttribute(four));
attr.ss_(two, {"hi", "now"});
ASSERT_EQ(attr.ss(two).at(1), "now");
Node* n2 = g.create(Symbol::fromQualString("foo::baz"));
Node& attr2 = *n2;
attr2.copyAttributes(attr);
ASSERT_EQ(attr2.s(one), "no");
attr2.f_(one, 5);
ASSERT_EQ(attr.s(one), "no");
ASSERT_EQ(attr2.f(one), 5);
}
void testBlocks() {
auto g = std::make_shared<Graph>();
// auto g = *graph;
auto a = Var::asNewInput(*g, "a");
auto b = Var::asNewInput(*g, "b");
auto c = a + b;
auto r =
g->appendNode(g->create(prim::If, {Var::asNewInput(*g, "c").value()}));
auto then_block = r->addBlock();
auto else_block = r->addBlock();
{
WithInsertPoint guard(then_block);
auto t = c + c;
then_block->registerOutput(t.value());
}
{
WithInsertPoint guard(else_block);
auto d = b + c;
auto e = d + c;
else_block->registerOutput(e.value());
}
g->registerOutput((Var(r->output()) + c).value());
g->lint();
testing::FileCheck()
.check("add")
->check("prim::If")
->check("block0")
->check("aten::add")
->check("block1")
->check_count("aten::add", 3)
->run(*g);
r->eraseBlock(0);
testing::FileCheck()
.check("add")
->check("prim::If")
->check("block0")
->check_not("block")
->run(*g);
g->lint();
// test recursive copy of blocks works
auto g2 = g->copy();
testing::FileCheck()
.check("add")
->check("prim::If")
->check("block0")
->check_not("block")
->run(*g2);
}
} // namespace jit
} // namespace torch