C++ API parity: Tanh

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27536

Test Plan: Imported from OSS

Differential Revision: D17835411

Pulled By: pbelevich

fbshipit-source-id: c8984aec2f4bae48ff901fafc8c53a4122192ac5
This commit is contained in:
Pavel Belevich
2019-10-13 06:32:32 -07:00
committed by Facebook Github Bot
parent 27027a4804
commit 2750ea25b2
3 changed files with 42 additions and 0 deletions

View File

@ -1368,6 +1368,15 @@ TEST_F(ModulesTest, Softsign) {
ASSERT_TRUE(torch::allclose(y, y_exp));
}
TEST_F(ModulesTest, Tanh) {
Tanh model;
auto x = torch::randn(100) * 10;
auto y_exp = (x.exp() - (-x).exp()) / (x.exp() + (-x).exp());
auto y = model(x);
ASSERT_TRUE(torch::allclose(y, y_exp));
}
TEST_F(ModulesTest, PrettyPrintIdentity) {
ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
}
@ -1726,3 +1735,7 @@ TEST_F(ModulesTest, PrettyPrintSoftshrink) {
TEST_F(ModulesTest, PrettyPrintSoftsign) {
ASSERT_EQ(c10::str(Softsign()), "torch::nn::Softsign()");
}
TEST_F(ModulesTest, PrettyPrintTanh) {
ASSERT_EQ(c10::str(Tanh()), "torch::nn::Tanh()");
}