mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Enable BFloat16 LeakyReLU and RReLU in CPU path (#61514)
Summary: Enable and optimize BFloat16 LeakyReLU and RReLU in CPU path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/61514 Reviewed By: ejguan Differential Revision: D30257612 Pulled By: VitalyFedyunin fbshipit-source-id: 8cc0d1faacd02dcc9827af724a86d95b6952748f
This commit is contained in:
committed by
Facebook GitHub Bot
parent
2ca2761f3c
commit
33a163d886
@ -2521,25 +2521,27 @@ TEST_F(ModulesTest, LeakyReLU) {
|
||||
const auto size = 3;
|
||||
for (const auto inplace : {false, true}) {
|
||||
for (const auto negative_slope : {0.0, 0.42, 1.0}) {
|
||||
LeakyReLU model {LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
for (const auto type : {torch::kFloat, torch::kBFloat16}) {
|
||||
LeakyReLU model {LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
} else {
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
} else {
|
||||
s.backward();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2740,26 +2742,28 @@ TEST_F(ModulesTest, RReLU) {
|
||||
for (const auto lower : {0.01, 0.1, 0.2}) {
|
||||
for (const auto upper : {0.3, 0.4, 0.5}) {
|
||||
for (const auto inplace : {false, true}) {
|
||||
RReLU model {RReLUOptions().lower(lower).upper(upper).inplace(inplace)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
for (const auto type : {torch::kFloat, torch::kBFloat16}) {
|
||||
RReLU model {RReLUOptions().lower(lower).upper(upper).inplace(inplace)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto z = ((x_orig >= 0) * (x_orig == y) +
|
||||
(x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) * 1.0;
|
||||
ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y));
|
||||
} else {
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto z = ((x_orig >= 0) * (x_orig == y) +
|
||||
(x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) * 1.0;
|
||||
ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y));
|
||||
} else {
|
||||
s.backward();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user