Add the appropriate check on div_value to the cpp frontend (#114671)

Fixes #114334

As the title stated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114671
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
FFFrog
2023-12-04 01:28:11 +00:00
committed by PyTorch MergeBot
parent 50833021dd
commit 541591dd79
2 changed files with 15 additions and 2 deletions

View File

@ -2679,6 +2679,18 @@ TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
ASSERT_TRUE(
torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output));
}
{
// test div_value
auto options =
AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.);
ASSERT_THROWS_WITH(
AdaptiveLogSoftmaxWithLoss(options),
"div_value should not be equal to 0");
options =
AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.25);
ASSERT_TRUE(AdaptiveLogSoftmaxWithLoss(options));
}
}
TEST_F(ModulesTest, Softmax2d) {