C++ API parity: at::Tensor::grad

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26150

Test Plan: Imported from OSS

Differential Revision: D17427579

Pulled By: pbelevich

fbshipit-source-id: 68d012076aa86dee9f23fad71a2d265d75f56d22
This commit is contained in:
Pavel Belevich
2019-09-18 09:19:00 -07:00
committed by Facebook Github Bot
parent 72aeafd3d0
commit 98ccae09af
6 changed files with 34 additions and 22 deletions

View File

@@ -216,7 +216,7 @@ TEST_F(ModulesTest, Linear) {
TEST_F(ModulesTest, Fold) {
Fold model(FoldOptions({4, 5}, {2, 2}));
auto x = torch::randn({1, 3 * 2 * 2, 12});
auto x = torch::randn({1, 3 * 2 * 2, 12}, torch::requires_grad());
auto y = model(x);
torch::Tensor s = y.sum();
@@ -240,7 +240,7 @@ TEST_F(ModulesTest, SimpleContainer) {
x = l2(x).clamp_min(0);
x = l3(x).clamp_min(0);
x.backward();
x.backward(torch::ones_like(x));
ASSERT_EQ(x.ndimension(), 2);
ASSERT_EQ(x.size(0), 1000);
ASSERT_EQ(x.size(1), 100);
@@ -288,7 +288,7 @@ TEST_F(ModulesTest, Dropout) {
torch::Tensor x = torch::ones(100, torch::requires_grad());
torch::Tensor y = dropout(x);
y.backward();
y.backward(torch::ones_like(y));
ASSERT_EQ(y.ndimension(), 1);
ASSERT_EQ(y.size(0), 100);
ASSERT_LT(y.sum().item<float>(), 130); // Probably