mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
C++ API parity: LogSigmoid
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27060 Test Plan: Imported from OSS Differential Revision: D17682404 Pulled By: pbelevich fbshipit-source-id: d60d64cd4caf1f56a2e05c516f91321d46ec9624
This commit is contained in:
committed by
Facebook Github Bot
parent
17c672e704
commit
2cc1e69cc9
@ -1099,6 +1099,23 @@ TEST_F(ModulesTest, LeakyReLU) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, LogSigmoid) {
|
||||
const auto size = 3;
|
||||
LogSigmoid model;
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({size, size, size}));
|
||||
auto y_exp = torch::log(torch::ones_like(x)/(torch::ones_like(x) + torch::exp(torch::neg(x))));
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintIdentity) {
|
||||
ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
|
||||
}
|
||||
@ -1349,3 +1366,7 @@ TEST_F(ModulesTest, PrettyPrintLeakyReLU) {
|
||||
LeakyReLUOptions().negative_slope(0.42).inplace(true))),
|
||||
"torch::nn::LeakyReLU(negative_slope=0.42, inplace=true)");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintLogSigmoid) {
|
||||
ASSERT_EQ(c10::str(LogSigmoid()), "torch::nn::LogSigmoid()");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user