mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
C++ API parity: PReLU
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27429 Test Plan: Imported from OSS Differential Revision: D17835412 Pulled By: pbelevich fbshipit-source-id: e678d5920dad1293bb0ba3de28e2da3087d19bde
This commit is contained in:
committed by
Facebook Github Bot
parent
0fbbc7acb4
commit
1fec1441a1
@ -1189,6 +1189,29 @@ TEST_F(ModulesTest, Softmax) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PReLU) {
|
||||
const auto num_parameters = 42;
|
||||
const auto init = 0.42;
|
||||
|
||||
PReLU model {PReLUOptions().num_parameters(num_parameters).init(init)};
|
||||
|
||||
ASSERT_EQ(model->weight.sizes(), torch::IntArrayRef({num_parameters}));
|
||||
ASSERT_TRUE(torch::allclose(model->weight,
|
||||
torch::full(num_parameters, init)));
|
||||
|
||||
const auto x = torch::rand({100, num_parameters}) * 200 - 100;
|
||||
const auto y = model(x);
|
||||
const auto s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), x.ndimension());
|
||||
ASSERT_EQ(y.sizes(), x.sizes());
|
||||
const auto y_exp = (x < 0) * model->weight * x + (x >= 0) * x;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintIdentity) {
|
||||
ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
|
||||
}
|
||||
@ -1480,3 +1503,9 @@ TEST_F(ModulesTest, PrettyPrintLogSigmoid) {
|
||||
TEST_F(ModulesTest, PrettyPrintSoftmax) {
|
||||
ASSERT_EQ(c10::str(Softmax(SoftmaxOptions(1))), "torch::nn::Softmax(dim=1)");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintPReLU) {
|
||||
ASSERT_EQ(c10::str(PReLU()), "torch::nn::PReLU(num_parameters=1)");
|
||||
ASSERT_EQ(c10::str(PReLU(PReLUOptions().num_parameters(42))),
|
||||
"torch::nn::PReLU(num_parameters=42)");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user