mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Add inplace tests for several torch::nn modules / functionals (#35147)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35147 Test Plan: Imported from OSS Differential Revision: D20578217 Pulled By: yf225 fbshipit-source-id: b8bafa49ee94c7dfbbca6e100ee3d9df5b2b621c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f515d87296
commit
bbec4520c6
@ -1300,54 +1300,81 @@ TEST_F(ModulesTest, FeatureAlphaDropout) {
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Dropout) {
|
||||
Dropout dropout(0.5);
|
||||
torch::Tensor x = torch::ones(100, torch::requires_grad());
|
||||
torch::Tensor y = dropout(x);
|
||||
for (const auto inplace : {false, true}) {
|
||||
Dropout dropout(DropoutOptions(0.5).inplace(inplace));
|
||||
torch::Tensor x = torch::ones(100);
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
torch::Tensor y = dropout(x);
|
||||
|
||||
y.backward(torch::ones_like(y));
|
||||
ASSERT_EQ(y.ndimension(), 1);
|
||||
ASSERT_EQ(y.size(0), 100);
|
||||
ASSERT_LT(y.sum().item<float>(), 130); // Probably
|
||||
ASSERT_GT(y.sum().item<float>(), 70); // Probably
|
||||
ASSERT_EQ(y.ndimension(), 1);
|
||||
ASSERT_EQ(y.size(0), 100);
|
||||
ASSERT_LT(y.sum().item<float>(), 130); // Probably
|
||||
ASSERT_GT(y.sum().item<float>(), 70); // Probably
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(y.allclose(x));
|
||||
} else {
|
||||
y.backward(torch::ones_like(y));
|
||||
}
|
||||
|
||||
dropout->eval();
|
||||
y = dropout(x);
|
||||
ASSERT_EQ(y.sum().item<float>(), 100);
|
||||
dropout->eval();
|
||||
y = dropout(torch::ones(100));
|
||||
ASSERT_EQ(y.sum().item<float>(), 100);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Dropout2d) {
|
||||
Dropout2d dropout(0.5);
|
||||
torch::Tensor x = torch::ones({10, 10}, torch::requires_grad());
|
||||
torch::Tensor y = dropout(x);
|
||||
for (const auto inplace : {false, true}) {
|
||||
Dropout2d dropout(Dropout2dOptions(0.5).inplace(inplace));
|
||||
torch::Tensor x = torch::ones({10, 10});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
torch::Tensor y = dropout(x);
|
||||
|
||||
y.backward(torch::ones_like(y));
|
||||
ASSERT_EQ(y.ndimension(), 2);
|
||||
ASSERT_EQ(y.size(0), 10);
|
||||
ASSERT_EQ(y.size(1), 10);
|
||||
ASSERT_LT(y.sum().item<float>(), 130); // Probably
|
||||
ASSERT_GT(y.sum().item<float>(), 70); // Probably
|
||||
ASSERT_EQ(y.ndimension(), 2);
|
||||
ASSERT_EQ(y.size(0), 10);
|
||||
ASSERT_EQ(y.size(1), 10);
|
||||
ASSERT_LT(y.sum().item<float>(), 130); // Probably
|
||||
ASSERT_GT(y.sum().item<float>(), 70); // Probably
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(y.allclose(x));
|
||||
} else {
|
||||
y.backward(torch::ones_like(y));
|
||||
}
|
||||
|
||||
dropout->eval();
|
||||
y = dropout(x);
|
||||
ASSERT_EQ(y.sum().item<float>(), 100);
|
||||
dropout->eval();
|
||||
y = dropout(torch::ones({10, 10}));
|
||||
ASSERT_EQ(y.sum().item<float>(), 100);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Dropout3d) {
|
||||
Dropout3d dropout(0.5);
|
||||
torch::Tensor x = torch::ones({4, 5, 5}, torch::requires_grad());
|
||||
torch::Tensor y = dropout(x);
|
||||
for (const auto inplace : {false, true}) {
|
||||
Dropout3d dropout(Dropout3dOptions(0.5).inplace(inplace));
|
||||
torch::Tensor x = torch::ones({4, 5, 5});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
torch::Tensor y = dropout(x);
|
||||
|
||||
y.backward(torch::ones_like(y));
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.size(0), 4);
|
||||
ASSERT_EQ(y.size(1), 5);
|
||||
ASSERT_EQ(y.size(1), 5);
|
||||
ASSERT_LT(y.sum().item<float>(), 130); // Probably
|
||||
ASSERT_GT(y.sum().item<float>(), 70); // Probably
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.size(0), 4);
|
||||
ASSERT_EQ(y.size(1), 5);
|
||||
ASSERT_EQ(y.size(1), 5);
|
||||
ASSERT_LT(y.sum().item<float>(), 130); // Probably
|
||||
ASSERT_GT(y.sum().item<float>(), 70); // Probably
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(y.allclose(x));
|
||||
} else {
|
||||
y.backward(torch::ones_like(y));
|
||||
}
|
||||
|
||||
dropout->eval();
|
||||
y = dropout(x);
|
||||
ASSERT_EQ(y.sum().item<float>(), 100);
|
||||
dropout->eval();
|
||||
y = dropout(torch::ones({4, 5, 5}));
|
||||
ASSERT_EQ(y.sum().item<float>(), 100);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Parameters) {
|
||||
@ -2147,38 +2174,58 @@ TEST_F(ModulesTest, PairwiseDistance) {
|
||||
TEST_F(ModulesTest, ELU) {
|
||||
const auto size = 3;
|
||||
for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
|
||||
ELU model {ELUOptions().alpha(alpha)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
for (const auto inplace : {false, true}) {
|
||||
ELU model {ELUOptions().alpha(alpha).inplace(inplace)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = torch::max(torch::zeros_like(x), x) +
|
||||
torch::min(torch::zeros_like(x), alpha * (torch::exp(x) - 1.0));
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
|
||||
torch::min(torch::zeros_like(x_orig), alpha * (torch::exp(x_orig) - 1.0));
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
} else {
|
||||
s.backward();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, SELU) {
|
||||
SELU model;
|
||||
auto input = torch::randn({5, 5}, torch::requires_grad());
|
||||
auto output = model->forward(input);
|
||||
const double scale = 1.0507009873554804934193349852946;
|
||||
const double alpha = 1.6732632423543772848170429916717;
|
||||
auto zero = torch::zeros_like(input);
|
||||
auto expected = scale *
|
||||
(torch::max(zero, input) +
|
||||
torch::min(zero, alpha * (torch::exp(input) - 1)));
|
||||
auto s = output.sum();
|
||||
s.backward();
|
||||
for (const auto inplace : {false, true}) {
|
||||
SELU model(inplace);
|
||||
auto input = torch::randn({5, 5});
|
||||
if (!inplace) {
|
||||
input.requires_grad_(true);
|
||||
}
|
||||
auto input_orig = input.clone();
|
||||
auto output = model->forward(input);
|
||||
const double scale = 1.0507009873554804934193349852946;
|
||||
const double alpha = 1.6732632423543772848170429916717;
|
||||
auto zero = torch::zeros_like(input);
|
||||
auto expected = scale *
|
||||
(torch::max(zero, input_orig) +
|
||||
torch::min(zero, alpha * (torch::exp(input_orig) - 1)));
|
||||
auto s = output.sum();
|
||||
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_TRUE(output.allclose(expected));
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_TRUE(output.allclose(expected));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(input.allclose(expected));
|
||||
} else {
|
||||
s.backward();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Hardshrink) {
|
||||
@ -2192,7 +2239,6 @@ TEST_F(ModulesTest, Hardshrink) {
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x.abs() > lambda) * x;
|
||||
@ -2204,21 +2250,30 @@ TEST_F(ModulesTest, Hardtanh) {
|
||||
const auto size = 3;
|
||||
for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) {
|
||||
for (const auto max_val : {0.42, 1.0, 4.2}) {
|
||||
Hardtanh model {HardtanhOptions().min_val(min_val).max_val(max_val)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
for (const auto inplace : {false, true}) {
|
||||
Hardtanh model {HardtanhOptions().min_val(min_val).max_val(max_val).inplace(inplace)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x < min_val) * min_val +
|
||||
((x >= min_val) * (x <= max_val)) * x +
|
||||
(x > max_val) * max_val;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x_orig < min_val) * min_val +
|
||||
((x_orig >= min_val) * (x_orig <= max_val)) * x_orig +
|
||||
(x_orig > max_val) * max_val;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
} else {
|
||||
s.backward();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2238,20 +2293,29 @@ TEST_F(ModulesTest, HardtanhMinValGEMaxVal) {
|
||||
|
||||
TEST_F(ModulesTest, LeakyReLU) {
|
||||
const auto size = 3;
|
||||
for (const auto negative_slope : {0.0, 0.42, 1.0}) {
|
||||
LeakyReLU model {LeakyReLUOptions().negative_slope(negative_slope)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
for (const auto inplace : {false, true}) {
|
||||
for (const auto negative_slope : {0.0, 0.42, 1.0}) {
|
||||
LeakyReLU model {LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x < 0) * x * negative_slope + (x >= 0) * x;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
} else {
|
||||
s.backward();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2394,78 +2458,114 @@ TEST_F(ModulesTest, PReLU) {
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, ReLU) {
|
||||
const auto size = 3;
|
||||
ReLU model;
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
for (const auto inplace : {false, true}) {
|
||||
const auto size = 3;
|
||||
ReLU model(inplace);
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x < 0) * 0 + (x >= 0) * x;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x_orig < 0) * 0 + (x_orig >= 0) * x_orig;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
} else {
|
||||
s.backward();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, ReLU6) {
|
||||
const auto size = 3;
|
||||
ReLU6 model;
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
for (const auto inplace : {false, true}) {
|
||||
const auto size = 3;
|
||||
ReLU6 model(inplace);
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x < 0) * 0 + ((x >= 0) * (x <= 6)) * x + (x > 6) * 6;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = (x_orig < 0) * 0 + ((x_orig >= 0) * (x_orig <= 6)) * x_orig + (x_orig > 6) * 6;
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
} else {
|
||||
s.backward();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, RReLU) {
|
||||
const auto size = 3;
|
||||
for (const auto lower : {0.01, 0.1, 0.2}) {
|
||||
for (const auto upper : {0.3, 0.4, 0.5}) {
|
||||
RReLU model {RReLUOptions().lower(lower).upper(upper)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
for (const auto inplace : {false, true}) {
|
||||
RReLU model {RReLUOptions().lower(lower).upper(upper).inplace(inplace)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto z = ((x >= 0) * (x == y) +
|
||||
(x < 0) * (y >= x * upper) * (y <= lower * x)) * 1.0;
|
||||
ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto z = ((x_orig >= 0) * (x_orig == y) +
|
||||
(x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) * 1.0;
|
||||
ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y));
|
||||
} else {
|
||||
s.backward();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, CELU) {
|
||||
const auto size = 3;
|
||||
for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) {
|
||||
CELU model {CELUOptions().alpha(alpha)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
for (const auto inplace : {false, true}) {
|
||||
for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) {
|
||||
CELU model {CELUOptions().alpha(alpha).inplace(inplace)};
|
||||
auto x = torch::linspace(-10.0, 10.0, size * size * size);
|
||||
x.resize_({size, size, size});
|
||||
if (!inplace) {
|
||||
x.requires_grad_(true);
|
||||
}
|
||||
auto x_orig = x.clone();
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = torch::max(torch::zeros_like(x), x) +
|
||||
torch::min(torch::zeros_like(x), alpha * (torch::exp(x / alpha) - 1.0));
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
|
||||
torch::min(torch::zeros_like(x_orig), alpha * (torch::exp(x_orig / alpha) - 1.0));
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
} else {
|
||||
s.backward();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2597,12 +2697,16 @@ TEST_F(ModulesTest, Threshold) {
|
||||
Threshold model {ThresholdOptions(threshold, value).inplace(inplace)};
|
||||
auto x = torch::linspace(-3.0, 3.0, 61);
|
||||
x.resize_({size, size, size});
|
||||
auto y_exp = (x <= threshold) * value + (x > threshold) * x;
|
||||
auto x_orig = x.clone();
|
||||
auto y_exp = (x_orig <= threshold) * value + (x_orig > threshold) * x_orig;
|
||||
auto y = model(x);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
|
||||
ASSERT_TRUE(torch::allclose(y, y_exp));
|
||||
if (inplace) {
|
||||
ASSERT_TRUE(torch::allclose(x, y_exp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user