mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
ENH Adds label_smoothing to cross entropy loss (#63122)
Summary: Fixes https://github.com/pytorch/pytorch/issues/7455 Partially resolves pytorch/vision#4281 Pull Request resolved: https://github.com/pytorch/pytorch/pull/63122 Reviewed By: iramazanli Differential Revision: D30586076 Pulled By: jbschlosser fbshipit-source-id: 06afc3aa1f8b9edb07fe9ed68c58968ad1926924
This commit is contained in:
committed by
Facebook GitHub Bot
parent
8af1407eab
commit
d3bcba5f85
@ -2315,6 +2315,31 @@ TEST_F(ModulesTest, CrossEntropyLoss) {
|
||||
ASSERT_TRUE(
|
||||
CrossEntropyLoss(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean))
|
||||
->forward(input, target).allclose(expected, 1e-04));
|
||||
|
||||
// label smoothing with class indices
|
||||
loss = CrossEntropyLoss(CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kMean));
|
||||
input = torch::tensor({{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
|
||||
target = torch::tensor({0, 1}, torch::kLong);
|
||||
output = loss->forward(input, target);
|
||||
expected = torch::tensor(0.3326, torch::kFloat);
|
||||
s = output.sum();
|
||||
s.backward();
|
||||
|
||||
ASSERT_TRUE(output.allclose(expected, 1e-04));
|
||||
ASSERT_EQ(input.sizes(), input.grad().sizes());
|
||||
|
||||
// label smoothing with with target probabilities
|
||||
loss = CrossEntropyLoss(CrossEntropyLossOptions().label_smoothing(0.2).reduction(torch::kMean));
|
||||
input = torch::tensor({{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
|
||||
target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
|
||||
output = loss->forward(input, target);
|
||||
expected = torch::tensor(0.5701, torch::kFloat);
|
||||
s = output.sum();
|
||||
s.backward();
|
||||
|
||||
ASSERT_TRUE(output.allclose(expected, 1e-04));
|
||||
ASSERT_EQ(input.sizes(), input.grad().sizes());
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, CosineSimilarity) {
|
||||
|
Reference in New Issue
Block a user