mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-26 16:44:54 +08:00
big cpp test reorg (#24801)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24801 This is to fix the ODR-violations in fbcode static builds, which have been broken for several months. This PR is unfortunately quite large, but the changes are only mechanical: 1. Tests defined in header files -> tests defined in cpp files 2. Remove the `torch::jit::testing` namespace -> `torch::jit`. 3. Single `test.h` file that aggregates all tests. 4. Separate out files for gtest and python versions of the tests instead of using a build flag 5. Add a readme for how to add a new test, and explaining a bit about why the cpp tests are the way they are. Test Plan: Imported from OSS Differential Revision: D16878605 Pulled By: suo fbshipit-source-id: 27b5c077dadd990a5f74e25d01731f9c1f491603
This commit is contained in:
committed by
Facebook Github Bot
parent
85564c1456
commit
dfdb86a595
34
test/cpp/jit/test_graph_executor.cpp
Normal file
34
test/cpp/jit/test_graph_executor.cpp
Normal file
@ -0,0 +1,34 @@
|
||||
#include "test/cpp/jit/test_base.h"
|
||||
#include "test/cpp/jit/test_utils.h"
|
||||
#include "torch/csrc/jit/graph_executor.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testGraphExecutor() {
|
||||
constexpr int batch_size = 4;
|
||||
constexpr int input_size = 256;
|
||||
|
||||
int hidden_size = 2 * input_size;
|
||||
|
||||
auto v = [](at::Tensor t) { return autograd::make_variable(t, false); };
|
||||
|
||||
auto input = at::randn({batch_size, input_size}, at::kCUDA);
|
||||
auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
|
||||
auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
|
||||
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
|
||||
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
|
||||
|
||||
auto g = build_lstm();
|
||||
GraphExecutor executor(g);
|
||||
auto stack = createStack({v(input), v(hx), v(cx), v(w_ih), v(w_hh)});
|
||||
executor.run(stack);
|
||||
ASSERT_EQ(stack.size(), 2);
|
||||
at::Tensor r0, r1;
|
||||
std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh);
|
||||
ASSERT_TRUE(almostEqual(stack[0].toTensor(), v(r0)));
|
||||
ASSERT_TRUE(almostEqual(stack[1].toTensor(), v(r1)));
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
Reference in New Issue
Block a user