mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
AdaptiveLogSoftmaxWithLoss no_batch_dim support (#69054)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69054 Test Plan: Imported from OSS Reviewed By: jbschlosser Differential Revision: D33200166 Pulled By: george-qi fbshipit-source-id: 9d953744351a25f372418d2a64e8402356d1e9b7
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0460324b9b
commit
8af39b7668
@ -2668,6 +2668,15 @@ TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
|
||||
ASSERT_TRUE(torch::allclose(out, logprob_out.gather(1, y.unsqueeze(1)).squeeze()));
|
||||
}
|
||||
}
|
||||
{
|
||||
// test no batch dim
|
||||
AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
|
||||
auto x = torch::randn({1, 16});
|
||||
auto y = torch::tensor({17});
|
||||
auto x2 = x.squeeze(0);
|
||||
auto y2 = y.squeeze(0);
|
||||
ASSERT_TRUE(torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Softmax2d) {
|
||||
|
Reference in New Issue
Block a user