mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 16:44:58 +08:00
Add the appropriate check on div_value to the cpp frontend (#114671)
Fixes #114334 As the title stated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114671 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
@ -2679,6 +2679,18 @@ TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
|
||||
ASSERT_TRUE(
|
||||
torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output));
|
||||
}
|
||||
{
|
||||
// test div_value
|
||||
auto options =
|
||||
AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.);
|
||||
ASSERT_THROWS_WITH(
|
||||
AdaptiveLogSoftmaxWithLoss(options),
|
||||
"div_value should not be equal to 0");
|
||||
|
||||
options =
|
||||
AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.25);
|
||||
ASSERT_TRUE(AdaptiveLogSoftmaxWithLoss(options));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Softmax2d) {
|
||||
|
||||
Reference in New Issue
Block a user