mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 16:44:58 +08:00
C++ API parity: at::Tensor::grad
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26150 Test Plan: Imported from OSS Differential Revision: D17427579 Pulled By: pbelevich fbshipit-source-id: 68d012076aa86dee9f23fad71a2d265d75f56d22
This commit is contained in:
committed by
Facebook Github Bot
parent
72aeafd3d0
commit
98ccae09af
@ -216,7 +216,7 @@ TEST_F(ModulesTest, Linear) {
|
||||
|
||||
TEST_F(ModulesTest, Fold) {
|
||||
Fold model(FoldOptions({4, 5}, {2, 2}));
|
||||
auto x = torch::randn({1, 3 * 2 * 2, 12});
|
||||
auto x = torch::randn({1, 3 * 2 * 2, 12}, torch::requires_grad());
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
@ -240,7 +240,7 @@ TEST_F(ModulesTest, SimpleContainer) {
|
||||
x = l2(x).clamp_min(0);
|
||||
x = l3(x).clamp_min(0);
|
||||
|
||||
x.backward();
|
||||
x.backward(torch::ones_like(x));
|
||||
ASSERT_EQ(x.ndimension(), 2);
|
||||
ASSERT_EQ(x.size(0), 1000);
|
||||
ASSERT_EQ(x.size(1), 100);
|
||||
@ -288,7 +288,7 @@ TEST_F(ModulesTest, Dropout) {
|
||||
torch::Tensor x = torch::ones(100, torch::requires_grad());
|
||||
torch::Tensor y = dropout(x);
|
||||
|
||||
y.backward();
|
||||
y.backward(torch::ones_like(y));
|
||||
ASSERT_EQ(y.ndimension(), 1);
|
||||
ASSERT_EQ(y.size(0), 100);
|
||||
ASSERT_LT(y.sum().item<float>(), 130); // Probably
|
||||
|
||||
Reference in New Issue
Block a user