Add C++ nn::Identity (#26713)

Summary:
**Summary**:
Adds `torch::nn::Identity` module support for the C++ API.

**Issue**: https://github.com/pytorch/pytorch/issues/25883

**Reviewer**: yf225
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26713

Differential Revision: D17550982

Pulled By: yf225

fbshipit-source-id: f24483846e82d5d276d77a1a0c50884f3bc05112
This commit is contained in:
jon-tow
2019-09-24 11:27:26 -07:00
committed by facebook-github-bot
parent c0c2921a06
commit 5e5b9a9321
3 changed files with 47 additions and 0 deletions

View File

@ -199,6 +199,18 @@ TEST_F(ModulesTest, AvgPool3d) {
ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2, 2}));
}
TEST_F(ModulesTest, Identity) {
Identity identity;
auto input = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::requires_grad());
auto output = identity->forward(input);
auto expected = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::kFloat);
auto s = output.sum();
s.backward();
ASSERT_TRUE(torch::equal(output, expected));
ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
}
TEST_F(ModulesTest, Linear) {
Linear model(5, 2);
auto x = torch::randn({10, 5}, torch::requires_grad());
@ -475,6 +487,10 @@ TEST_F(ModulesTest, PairwiseDistance) {
ASSERT_EQ(input1.sizes(), input1.grad().sizes());
}
TEST_F(ModulesTest, PrettyPrintIdentity) {
ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
}
TEST_F(ModulesTest, PrettyPrintLinear) {
ASSERT_EQ(
c10::str(Linear(3, 4)), "torch::nn::Linear(in=3, out=4, with_bias=true)");