AdaptiveLogSoftmaxWithLoss no_batch_dim support (#69054)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69054

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D33200166

Pulled By: george-qi

fbshipit-source-id: 9d953744351a25f372418d2a64e8402356d1e9b7
Author: George Qi
Date: 2021-12-29 10:20:01 -08:00
Committed by: Facebook GitHub Bot
Parent: 0460324b9b
Commit: 8af39b7668

4 changed files with 72 additions and 12 deletions

@@ -2668,6 +2668,15 @@ TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
ASSERT_TRUE(torch::allclose(out, logprob_out.gather(1, y.unsqueeze(1)).squeeze()));
}
}
{
  // test no batch dim
  AdaptiveLogSoftmaxWithLoss asfm(
      AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
  auto x = torch::randn({1, 16});
  auto y = torch::tensor({17});
  auto x2 = x.squeeze(0);
  auto y2 = y.squeeze(0);
  ASSERT_TRUE(torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output));
}
}
TEST_F(ModulesTest, Softmax2d) {
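
The new test branch checks that a single unbatched sample produces the same result as a batch of one. A minimal standalone sketch of the same call pattern follows; the module name and options are copied from the test above, while the main wrapper, the unbatched input construction, and the printing are illustrative assumptions rather than part of the commit.

#include <torch/torch.h>
#include <iostream>

int main() {
  using namespace torch::nn;

  // Same configuration as the test: 16 input features, 20 classes,
  // cutoffs {4, 10, 15}, div_value 2.0.
  AdaptiveLogSoftmaxWithLoss asfm(
      AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));

  // Unbatched input: a single 16-dim feature vector and a 0-dim target,
  // built the same way the test builds x2 and y2 via squeeze(0).
  auto x = torch::randn({16});
  auto y = torch::tensor({17}).squeeze(0);

  // With no_batch_dim support, the module accepts input without a leading
  // batch dimension; the returned output is likewise unbatched.
  auto result = asfm(x, y);
  std::cout << "output: " << result.output << "\n"
            << "loss:   " << result.loss << "\n";
  return 0;
}

Feeding the batched {1, 16} input instead and squeezing dim 0 of the result, as the assertion in the test does, should yield an identical output.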