mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24801 This fixes the ODR violations in fbcode static builds, which have been broken for several months. This PR is unfortunately quite large, but the changes are purely mechanical: 1. Tests defined in header files -> tests defined in .cpp files. 2. Remove the `torch::jit::testing` namespace -> `torch::jit`. 3. A single `test.h` file that aggregates all tests. 4. Separate files for the gtest and Python versions of the tests, instead of selecting between them with a build flag. 5. A README explaining how to add a new test, and a bit about why the C++ tests are structured the way they are. Test Plan: Imported from OSS. Differential Revision: D16878605. Pulled By: suo. fbshipit-source-id: 27b5c077dadd990a5f74e25d01731f9c1f491603
86 lines
2.1 KiB
C++
86 lines
2.1 KiB
C++
#include "test/cpp/jit/test_base.h"
|
|
#include "test/cpp/jit/test_utils.h"
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
void testAttributes() {
|
|
Graph g;
|
|
auto one = attr::alpha;
|
|
auto two = attr::device;
|
|
auto three = attr::end;
|
|
auto four = attr::perm;
|
|
Node* n = g.create(Symbol::fromQualString("foo::bar"));
|
|
Node& attr = *n;
|
|
attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what");
|
|
ASSERT_EQ(attr.f(one), 3.4);
|
|
ASSERT_EQ(attr.s(three), "what");
|
|
ASSERT_EQ(attr.i(two), 5);
|
|
attr.s_(one, "no");
|
|
ASSERT_EQ(attr.s(one), "no");
|
|
ASSERT_TRUE(attr.hasAttribute(three));
|
|
ASSERT_TRUE(!attr.hasAttribute(four));
|
|
attr.ss_(two, {"hi", "now"});
|
|
ASSERT_EQ(attr.ss(two).at(1), "now");
|
|
|
|
Node* n2 = g.create(Symbol::fromQualString("foo::baz"));
|
|
Node& attr2 = *n2;
|
|
attr2.copyAttributes(attr);
|
|
ASSERT_EQ(attr2.s(one), "no");
|
|
attr2.f_(one, 5);
|
|
ASSERT_EQ(attr.s(one), "no");
|
|
ASSERT_EQ(attr2.f(one), 5);
|
|
}
|
|
|
|
void testBlocks() {
|
|
auto g = std::make_shared<Graph>();
|
|
// auto g = *graph;
|
|
auto a = Var::asNewInput(*g, "a");
|
|
auto b = Var::asNewInput(*g, "b");
|
|
auto c = a + b;
|
|
auto r =
|
|
g->appendNode(g->create(prim::If, {Var::asNewInput(*g, "c").value()}));
|
|
auto then_block = r->addBlock();
|
|
auto else_block = r->addBlock();
|
|
{
|
|
WithInsertPoint guard(then_block);
|
|
auto t = c + c;
|
|
then_block->registerOutput(t.value());
|
|
}
|
|
{
|
|
WithInsertPoint guard(else_block);
|
|
auto d = b + c;
|
|
auto e = d + c;
|
|
else_block->registerOutput(e.value());
|
|
}
|
|
g->registerOutput((Var(r->output()) + c).value());
|
|
g->lint();
|
|
testing::FileCheck()
|
|
.check("add")
|
|
->check("prim::If")
|
|
->check("block0")
|
|
->check("aten::add")
|
|
->check("block1")
|
|
->check_count("aten::add", 3)
|
|
->run(*g);
|
|
r->eraseBlock(0);
|
|
testing::FileCheck()
|
|
.check("add")
|
|
->check("prim::If")
|
|
->check("block0")
|
|
->check_not("block")
|
|
->run(*g);
|
|
g->lint();
|
|
// test recursive copy of blocks works
|
|
auto g2 = g->copy();
|
|
testing::FileCheck()
|
|
.check("add")
|
|
->check("prim::If")
|
|
->check("block0")
|
|
->check_not("block")
|
|
->run(*g2);
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|