mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 17:24:59 +08:00
Add C++ nn::Identity (#26713)
Summary: **Summary**: Adds `torch::nn::Identity` module support for the C++ API. **Issue**: https://github.com/pytorch/pytorch/issues/25883 **Reviewer**: yf225 Pull Request resolved: https://github.com/pytorch/pytorch/pull/26713 Differential Revision: D17550982 Pulled By: yf225 fbshipit-source-id: f24483846e82d5d276d77a1a0c50884f3bc05112
This commit is contained in:
committed by
facebook-github-bot
parent
c0c2921a06
commit
5e5b9a9321
@ -199,6 +199,18 @@ TEST_F(ModulesTest, AvgPool3d) {
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2, 2}));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Identity) {
|
||||
Identity identity;
|
||||
auto input = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::requires_grad());
|
||||
auto output = identity->forward(input);
|
||||
auto expected = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::kFloat);
|
||||
auto s = output.sum();
|
||||
s.backward();
|
||||
|
||||
ASSERT_TRUE(torch::equal(output, expected));
|
||||
ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Linear) {
|
||||
Linear model(5, 2);
|
||||
auto x = torch::randn({10, 5}, torch::requires_grad());
|
||||
@ -475,6 +487,10 @@ TEST_F(ModulesTest, PairwiseDistance) {
|
||||
ASSERT_EQ(input1.sizes(), input1.grad().sizes());
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintIdentity) {
|
||||
ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintLinear) {
|
||||
ASSERT_EQ(
|
||||
c10::str(Linear(3, 4)), "torch::nn::Linear(in=3, out=4, with_bias=true)");
|
||||
|
||||
Reference in New Issue
Block a user