mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33851 Rationale and context described in #33828. Script to reproduce the move: https://gist.github.com/suo/16cbefaaeb67ca5a7c6caffd49b7f6e9 ghstack-source-id: 99079645 Test Plan: Make sure CI passes Reviewed By: jamesr66a Differential Revision: D20133869 fbshipit-source-id: 390e9241a9c85366d9005c492ac31f10aa96488e
33 lines
1014 B
C++
33 lines
1014 B
C++
#include "test/cpp/jit/test_base.h"
|
|
#include "test/cpp/jit/test_utils.h"
|
|
#include "torch/csrc/jit/runtime/graph_executor.h"
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
void testGraphExecutor() {
|
|
constexpr int batch_size = 4;
|
|
constexpr int input_size = 256;
|
|
|
|
int hidden_size = 2 * input_size;
|
|
|
|
auto input = at::randn({batch_size, input_size}, at::kCUDA);
|
|
auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
|
|
auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
|
|
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
|
|
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
|
|
|
|
auto g = build_lstm();
|
|
GraphExecutor executor(g);
|
|
auto stack = createStack({input, hx, cx, w_ih, w_hh});
|
|
executor.run(stack);
|
|
ASSERT_EQ(stack.size(), 2);
|
|
at::Tensor r0, r1;
|
|
std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh);
|
|
ASSERT_TRUE(almostEqual(stack[0].toTensor(), r0));
|
|
ASSERT_TRUE(almostEqual(stack[1].toTensor(), r1));
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|