big cpp test reorg (#24801)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24801

This is to fix the ODR-violations in fbcode static builds, which have been broken for several months.

This PR is unfortunately quite large, but the changes are only mechanical:
1. Tests defined in header files -> tests defined in cpp files
2. Remove the `torch::jit::testing` namespace -> `torch::jit`.
3. Single `test.h` file that aggregates all tests.
4. Separate out files for gtest and python versions of the tests instead of using a build flag
5. Add a readme for how to add a new test, and explaining a bit about why the cpp tests are the way they are.

Test Plan: Imported from OSS

Differential Revision: D16878605

Pulled By: suo

fbshipit-source-id: 27b5c077dadd990a5f74e25d01731f9c1f491603
This commit is contained in:
Michael Suo
2019-08-18 16:46:56 -07:00
committed by Facebook Github Bot
parent 85564c1456
commit dfdb86a595
37 changed files with 328 additions and 293 deletions

View File

@ -0,0 +1,34 @@
#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/graph_executor.h"
namespace torch {
namespace jit {
void testGraphExecutor() {
constexpr int batch_size = 4;
constexpr int input_size = 256;
int hidden_size = 2 * input_size;
auto v = [](at::Tensor t) { return autograd::make_variable(t, false); };
auto input = at::randn({batch_size, input_size}, at::kCUDA);
auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
auto g = build_lstm();
GraphExecutor executor(g);
auto stack = createStack({v(input), v(hx), v(cx), v(w_ih), v(w_hh)});
executor.run(stack);
ASSERT_EQ(stack.size(), 2);
at::Tensor r0, r1;
std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh);
ASSERT_TRUE(almostEqual(stack[0].toTensor(), v(r0)));
ASSERT_TRUE(almostEqual(stack[1].toTensor(), v(r1)));
}
} // namespace jit
} // namespace torch