mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Extract reusable portions of elu_kernel into header (#149673)
Similar to #140425, we are making the implementation usable via header-only code sharing. Review note: #62546 by @yanbing-j removed expm1 usage from this path. I don't know why and expm1 should be more efficient, so I've put it back. Please let me know if there is a good reason I shouldn't. Testing: existing correctness tests should cover. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149673 Approved by: https://github.com/cyyever, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
b238e36fd9
commit
c73a526599
@ -2432,8 +2432,7 @@ TEST_F(ModulesTest, ELU) {
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
|
||||
torch::min(torch::zeros_like(x_orig),
|
||||
alpha * (torch::exp(x_orig) - 1.0));
|
||||
torch::min(torch::zeros_like(x_orig), alpha * (torch::expm1(x_orig)));
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
@ -2458,7 +2457,7 @@ TEST_F(ModulesTest, SELU) {
|
||||
auto zero = torch::zeros_like(input);
|
||||
auto expected = scale *
|
||||
(torch::max(zero, input_orig) +
|
||||
torch::min(zero, alpha * (torch::exp(input_orig) - 1)));
|
||||
torch::min(zero, alpha * (torch::expm1(input_orig))));
|
||||
auto s = output.sum();
|
||||
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
@ -2848,7 +2847,7 @@ TEST_F(ModulesTest, CELU) {
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
|
||||
torch::min(torch::zeros_like(x_orig),
|
||||
alpha * (torch::exp(x_orig / alpha) - 1.0));
|
||||
alpha * (torch::expm1(x_orig / alpha)));
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
|
Reference in New Issue
Block a user