mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:30:26 +08:00
Fixes issues introduced in https://github.com/pytorch/pytorch/pull/141348 and https://github.com/pytorch/pytorch/pull/139578 Pull Request resolved: https://github.com/pytorch/pytorch/pull/142514 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
1713 lines
52 KiB
C++
1713 lines
52 KiB
C++
#include <ATen/core/boxing/impl/test_helpers.h>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <ATen/core/op_registration/op_registration.h>
|
|
#include <torch/torch.h>
|
|
|
|
#include <torch/csrc/autograd/FunctionsManual.h>
|
|
#include <torch/csrc/autograd/engine.h>
|
|
#include <torch/csrc/autograd/functions/basic_ops.h>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
using namespace torch::autograd;
|
|
using namespace torch::test;
|
|
|
|
// Tensor comparison helpers for gtest. Both forward to torch::allclose, so
// "equal" means element-wise equal within allclose's default tolerances.
// ASSERT_ aborts the current test function on mismatch; EXPECT_ records the
// failure and continues.
#define ASSERT_VARIABLE_EQ(lhs, rhs) \
  ASSERT_TRUE(torch::allclose((lhs), (rhs)))
#define EXPECT_VARIABLE_EQ(lhs, rhs) \
  EXPECT_TRUE(torch::allclose((lhs), (rhs)))
|
|
|
|
std::string graph_desc(std::shared_ptr<Node> node) {
|
|
if (!node) {
|
|
return "None";
|
|
}
|
|
auto result = node->name() + "(";
|
|
auto next_edges = node->next_edges();
|
|
for (auto& edge : next_edges) {
|
|
result += graph_desc(edge.function);
|
|
}
|
|
return result + ")";
|
|
}
|
|
|
|
// Toy function f(x, y) = x + 2y + xy used by the gradient tests.
// Its partials are df/dx = 1 + y and df/dy = 2 + x.
Variable simple_fn(const Variable& x, const Variable& y) {
  Variable linear_part = x + 2 * y;
  Variable cross_term = x * y;
  return linear_part + cross_term;
}
|
|
|
|
TEST(AutogradAPITests, RegisterHookVoidReturnAcceptsUndefinedTensor) {
  // A void-returning hook must not choke when the gradient reaching the
  // tensor is undefined; UndefinedGrad is used to produce that situation.
  auto leaf = at::zeros({}, at::kCPU);
  leaf.requires_grad_();
  auto ignore_grad = [](at::TensorBase /*grad*/) { return; };
  leaf.register_hook(ignore_grad);
  auto outputs = torch::autograd::UndefinedGrad().apply({leaf});
  outputs[0].backward();
}
|
|
|
|
TEST(AutogradAPITests, RegisterHookTensorReturnAcceptsUndefinedTensor) {
  // Same scenario as the void-hook test, but with a hook that returns a
  // (pass-through) tensor.
  auto leaf = at::zeros({}, at::kCPU);
  leaf.requires_grad_();
  auto passthrough = [](at::Tensor grad) -> at::Tensor { return grad; };
  leaf.register_hook(passthrough);
  auto outputs = torch::autograd::UndefinedGrad().apply({leaf});
  outputs[0].backward();
}
|
|
|
|
TEST(AutogradAPITests, BackwardSimpleTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  // Reduce to a scalar so backward() needs no explicit grad_outputs.
  Variable loss = simple_fn(x, y).sum();
  backward({loss}, {});

  // d/dx (x + 2y + xy) = 1 + y,  d/dy = 2 + x
  const auto ones = torch::ones({2, 2});
  ASSERT_VARIABLE_EQ(x.grad(), y + ones);
  ASSERT_VARIABLE_EQ(y.grad(), x + ones * 2);
}
|
|
|
|
TEST(AutogradAPITests, BackwardTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable out = simple_fn(x, y);

  // The first pass keeps the graph (retain_graph=true) so that the second
  // pass can traverse it again; both passes accumulate into .grad.
  backward({out}, {torch::ones({2, 2})}, {}, true);
  backward({out}, {torch::ones({2, 2})});

  // Two accumulated passes => twice the single-pass gradient.
  Variable expected_dx = 2 * (y + torch::ones({2, 2}));
  Variable expected_dy = 2 * (x + torch::ones({2, 2}) * 2);
  ASSERT_VARIABLE_EQ(x.grad(), expected_dx);
  ASSERT_VARIABLE_EQ(y.grad(), expected_dy);
}
|
|
|
|
TEST(AutogradAPITests, GradSimpleTest) {
  // grad() returns the gradients directly instead of accumulating them.
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable out = simple_fn(x, y);
  variable_list grads = grad({out}, {x, y}, {torch::ones({2, 2})});

  Variable& dx = grads[0];
  Variable& dy = grads[1];
  ASSERT_VARIABLE_EQ(dx, y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(dy, x + torch::ones({2, 2}) * 2);
}
|
|
|
|
TEST(AutogradAPITests, GradTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable out = simple_fn(x, y);
  // retain_graph=false, create_graph=true: the resulting .grad tensors are
  // themselves differentiable.
  out.backward(torch::ones({2, 2}), false, true);

  Variable expected_dx = y + torch::ones({2, 2});
  Variable expected_dy = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(x.grad(), expected_dx);
  ASSERT_VARIABLE_EQ(y.grad(), expected_dy);

  // 2 * (1 + y) + (2 + x) has an all-ones derivative w.r.t. x. Computing it
  // with grad() must leave the accumulated .grad values untouched.
  Variable combined = 2 * x.grad() + y.grad();
  variable_list hvp = grad({combined}, {x}, {torch::ones({2, 2})}, {}, true);

  ASSERT_VARIABLE_EQ(hvp[0], torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(x.grad(), expected_dx);
  ASSERT_VARIABLE_EQ(y.grad(), expected_dy);
}
|
|
|
|
TEST(AutogradAPITests, GradNonLeafTest) {
  Variable x_init = torch::randn({2, 2}, torch::requires_grad());
  Variable x = x_init;
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable seed = torch::ones({2, 2});

  // A few gradient-ascent steps on a non-leaf x. Each step is built on top of
  // the previous x, so x stays differentiable back to x_init.
  for (int step = 0; step < 5; ++step) {
    Variable out = simple_fn(x, y);
    variable_list dx = grad({out}, {x}, {seed}, {}, true);

    ASSERT_VARIABLE_EQ(dx[0], y + torch::ones({2, 2}));
    // grad() must not populate .grad on either tensor.
    ASSERT_FALSE(x.grad().defined());
    ASSERT_FALSE(y.grad().defined());
    x = x + 0.05 * dx[0];
  }

  // Moving along the gradient should have increased the objective.
  const float before = simple_fn(x_init, y).sum().item().toFloat();
  const float after = simple_fn(x, y).sum().item().toFloat();
  ASSERT_TRUE(after > before);

  // The steps were recorded with create_graph=true, so backprop from the
  // final x reaches both original leaves.
  x.backward(seed, false, true);
  ASSERT_TRUE(x_init.grad().defined());
  ASSERT_TRUE(y.grad().defined());
}
|
|
|
|
TEST(AutogradAPITests, GradUnreachableTest) {
  Variable x = torch::ones({1}, torch::requires_grad());
  Variable y = torch::ones({1}, torch::requires_grad());

  // Use both leaves in (unused) graphs so each one gets a grad accumulator,
  // even though only x feeds the outputs queried below.
  Variable z = x * 2;
  Variable w = y * 2;

  // allow_unused=true: y cannot be reached from x * 2, so its slot is
  // undefined. x is all ones, so d(2x)/dx happens to equal x * 2.
  variable_list grads = grad({x * 2}, {x, y}, {}, {}, false, true);
  ASSERT_VARIABLE_EQ(grads[0], x * 2);
  ASSERT_FALSE(grads[1].defined());

  // Same check, but the unreachable input is a fresh leaf that was never
  // used in any graph and therefore has no grad accumulator at all.
  z = torch::ones({1}, torch::requires_grad());
  grads = grad({x * 2}, {x, z}, {}, {}, false, true);
  ASSERT_VARIABLE_EQ(grads[0], x * 2);
  ASSERT_FALSE(grads[1].defined());

  // With allow_unused=false the missing gradient is an error.
  ASSERT_THROWS_WITH(
      grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
}
|
|
|
|
TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
  // Regression test for #39784: when the requested input is unreachable from
  // the output, nodes of the output's graph must not be executed.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* /*ctx*/, Variable input) {
      return input;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      ADD_FAILURE() << "This node should not be executed!";
      return grads;
    }
  };

  auto x = torch::randn(1, torch::requires_grad());
  auto offset = torch::randn(1);
  auto guarded = MyFunction::apply(x + offset);

  // y does not participate in `guarded`'s graph.
  auto y = torch::randn(1, torch::requires_grad());
  variable_list result =
      torch::autograd::grad({guarded}, {y}, {}, {}, false, true);
  ASSERT_FALSE(result[0].defined());
}
|
|
|
|
TEST(AutogradAPITests, EmptyInput) {
  // grad() rejects an empty `inputs` list.
  Variable leaf = torch::ones({1}, torch::requires_grad());
  Variable doubled = leaf * 2;
  ASSERT_THROWS_WITH(
      grad({doubled}, /*inputs=*/{}, {leaf}),
      "grad requires non-empty inputs.");
}
|
|
|
|
TEST(AutogradAPITests, RetainGrad) {
  auto input = torch::rand({1, 3}, torch::requires_grad());
  auto h1 = input * 3;
  auto out = (h1 * h1).sum();

  // Reading .grad of a non-leaf that does not retain it emits a warning.
  {
    WarningCapture captured;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_TRUE(captured.str().find("is not a leaf") != std::string::npos);
  }

  // retain_grad() may be called repeatedly; once set, reading .grad on the
  // non-leaf no longer warns.
  for (int i = 0; i < 2; ++i) {
    h1.retain_grad();
  }
  {
    WarningCapture captured;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_FALSE(captured.str().find("is not a leaf") != std::string::npos);
  }

  // d(sum(h1^2))/dh1 = 2 * h1, accumulated across backward passes.
  // NOLINTNEXTLINE(bugprone-argument-comment)
  out.backward({}, /*keep_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
  // NOLINTNEXTLINE(bugprone-argument-comment)
  out.backward({}, /*keep_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());

  // Reset the leaf's accumulated gradient without recording the op.
  {
    torch::NoGradGuard no_grad;
    input.grad().zero_();
  }
  // On a leaf, retain_grad() is a no-op.
  input.retain_grad();
  input.retain_grad();
  out.backward();
  // d(sum((3 * input)^2))/d(input) = 18 * input
  ASSERT_VARIABLE_EQ(input * 18, input.grad());
}
|
|
|
|
TEST(AutogradAPITests, AnomalyMode) {
  // With anomaly detection on, a failing backward must both throw and emit
  // the forward traceback as a warning.
  torch::autograd::DetectAnomalyGuard detect_anomaly;
  {
    WarningCapture warnings;
    auto x = torch::tensor({5.0}, torch::requires_grad());
    auto y = x * x;
    auto z = y * y;
    // Modifying y in place after z used it invalidates z's backward.
    y += 1;
    ASSERT_THROWS_WITH(z.backward(), "inplace");
    ASSERT_TRUE(
        warnings.str().find("Traceback of forward") != std::string::npos);
  }

  // The double backward of x^1.5 at x = 0 yields NaN. When NaN checking is
  // active this must throw and report two tracebacks: the op that produced
  // the NaN and the forward op that induced it.
  auto double_backward_produce_nan = [](bool should_throw) {
    auto x = torch::tensor({0.0}, torch::requires_grad());
    auto y = x.pow(1.5);
    auto gr =
        // NOLINTNEXTLINE(bugprone-argument-comment)
        grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
    if (should_throw) {
      WarningCapture warnings;
      // No trailing ';' inside the macro argument: it would only add an
      // empty statement and leak into the stringified failure message.
      ASSERT_THROWS_WITH(
          grad({gr[0]}, {x}, {torch::tensor({0.0})}), "returned nan");
      auto msgs = warnings.messages();
      // msgs.size() is unsigned; compare against an unsigned literal.
      ASSERT_EQ(msgs.size(), 2u);
      ASSERT_TRUE(
          msgs[0].find("Traceback of forward call that caused the error") !=
          std::string::npos);
      ASSERT_TRUE(
          msgs[1].find(
              "Traceback of forward call that induced the previous calculation") !=
          std::string::npos);
    } else {
      grad({gr[0]}, {x}, {torch::tensor({0.0})});
    }
  };

  // The default guard checks for NaN. A nested guard can disable the check,
  // a further nested guard can re-enable it, and leaving each scope restores
  // the previous setting.
  double_backward_produce_nan(true);
  {
    torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/false);
    double_backward_produce_nan(false);
    {
      torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/true);
      double_backward_produce_nan(true);
    }
  }
  double_backward_produce_nan(true);
}
|
|
|
|
TEST(CustomAutogradTest, CustomFunctionReturnInputAsIsAndSavesIt) {
  // A custom function may save an input for backward and also return that
  // very same input unchanged.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        Variable var2) {
      ctx->save_for_backward({var1, var2});
      // Return the input itself; this is the case under test. (A comma
      // expression such as `a * b, var1` would compute and discard a product
      // while still returning var1.)
      return var1;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list /*grad_output*/) {
      return {};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable out = MyFunction::apply(x, y);
  // The output carries exactly the input's values.
  ASSERT_TRUE(torch::equal(out, x));
}
|
|
|
|
TEST(CustomAutogradTest, CustomFunction) {
  // f(var1, mul, var2) = var1 + mul * var2 + var1 * var2 with a hand-written
  // backward. The non-tensor `mul` argument gets an undefined gradient.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        int mul,
        Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const int mul = ctx->saved_data["mul"].toInt();
      const variable_list saved = ctx->get_saved_variables();
      const Variable& g = grad_output[0];
      // One gradient slot per forward argument, in declaration order.
      return {
          g + g * saved[1], // d/dvar1 = 1 + var2
          Variable(), // `mul` is not differentiable
          g * mul + g * saved[0]}; // d/dvar2 = mul + var1
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable result = MyFunction::apply(x, 2, y);
  // 0-dim seed gradient that itself requires grad; create_graph=true.
  auto seed = torch::ones({}, torch::requires_grad());
  result.sum().backward(seed, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}
|
|
|
|
TEST(CustomAutogradTest, CustomFunctionWithTensorList) {
  // Same math as CustomFunction (with mul = 1), but the inputs arrive as a
  // single TensorList argument.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, at::TensorList tensors) {
      // Save every tensor of the list for backward.
      ctx->save_for_backward(
          torch::autograd::variable_list(tensors.begin(), tensors.end()));
      return tensors[0] + tensors[1] + tensors[0] * tensors[1];
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const variable_list saved = ctx->get_saved_variables();
      const Variable& g = grad_output[0];
      // d/dt0 = 1 + t1, d/dt1 = 1 + t0
      return {g + g * saved[1], g + g * saved[0]};
    }
  };

  at::Tensor x = torch::randn({5, 5}, torch::requires_grad());
  at::Tensor y = torch::randn({5, 5}, torch::requires_grad());
  torch::autograd::variable_list pair = {x, y};
  at::TensorList as_list = pair;
  auto result = MyFunction::apply(as_list);
  auto seed = torch::ones({}, torch::requires_grad());
  result.sum().backward(seed, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}));
}
|
|
|
|
TEST(CustomAutogradTest, GraphTaskTrimEdges) {
  // The engine should only ask for gradients of the inputs actually
  // requested. backward() verifies this through needs_input_grad(),
  // comparing against expectations passed in through forward().
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        Variable var2,
        int mul,
        bool needs_input1_grad,
        bool needs_input2_grad) {
      // Stash the expectations so backward() can check them.
      ctx->saved_data["needs_input1_grad"] = needs_input1_grad;
      ctx->saved_data["needs_input2_grad"] = needs_input2_grad;

      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      // needs_input_grad() can only be queried from inside backward.
      const bool want1 = ctx->saved_data["needs_input1_grad"].toBool();
      const bool want2 = ctx->saved_data["needs_input2_grad"].toBool();
      const IndexRange first = {0, 1};
      const IndexRange second = {1, 2};
      EXPECT_EQ(ctx->needs_input_grad(0), want1);
      EXPECT_EQ(ctx->needs_input_grad(1), want2);
      EXPECT_EQ(ctx->needs_input_grad({first}), want1);
      EXPECT_EQ(ctx->needs_input_grad({second}), want2);
      EXPECT_EQ(ctx->needs_input_grad({first, second}), want1 || want2);

      const int mul = ctx->saved_data["mul"].toInt();
      const variable_list saved = ctx->get_saved_variables();
      const Variable& g = grad_output[0];

      // Only materialize the gradients the engine asked for.
      Variable d_var1 =
          ctx->needs_input_grad(0) ? g + g * saved[1] : Variable();
      Variable d_var2 =
          ctx->needs_input_grad(1) ? g * mul + g * saved[0] : Variable();
      // Trailing slots: `mul` and the two flags are not differentiable.
      return {d_var1, d_var2, Variable(), Variable(), Variable()};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto go = torch::ones_like(x);
  auto run = [&](bool need_x, bool need_y) {
    return MyFunction::apply(x, y, 2, need_x, need_y);
  };
  Variable out;

  // Only d/dx requested.
  out = run(/*need_x=*/true, /*need_y=*/false);
  Variable grad_x = torch::autograd::grad({out}, {x}, {go})[0];
  ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));

  // Only d/dy requested.
  out = run(/*need_x=*/false, /*need_y=*/true);
  Variable grad_y = torch::autograd::grad({out}, {y}, {go})[0];
  ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);

  // Both requested.
  out = run(/*need_x=*/true, /*need_y=*/true);
  variable_list both = torch::autograd::grad({out}, {x, y}, {go});
  ASSERT_VARIABLE_EQ(both[0], y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(both[1], x + torch::ones({5, 5}) * 2);
}
|
|
|
|
TEST(CustomAutogradTest, FunctionReturnsInput) {
  // Identity forward whose backward doubles the gradient, so the result shows
  // that the custom backward was actually used.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* /*ctx*/, Variable var1) {
      return var1;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      Variable doubled = grad_output[0] * 2;
      return {doubled};
    }
  };

  Variable x(torch::ones(1, torch::requires_grad()));
  Variable y = MyFunction::apply(x);
  y.backward(torch::ones(1), true, true);
  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
}
|
|
|
|
TEST(CustomAutogradTest, FunctionReturnsUndefined) {
  // A backward that returns an undefined gradient must leave x.grad()
  // undefined, whatever surrounds the custom node.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* /*ctx*/, Variable var) {
      return var * 2;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list /*grad_output*/) {
      return {at::Tensor()};
    }
  };

  auto x = torch::ones(1, torch::requires_grad());

  // Custom node applied directly to the leaf.
  MyFunction::apply(x).backward();
  ASSERT_FALSE(x.grad().defined());

  // Custom node after another differentiable op.
  MyFunction::apply(x.pow(2)).backward();
  ASSERT_FALSE(x.grad().defined());

  // Custom node followed by a reduction.
  MyFunction::apply(x).sum().backward();
  ASSERT_FALSE(x.grad().defined());

  // Functional API with allow_unused=true.
  variable_list grads = torch::autograd::grad(
      {MyFunction::apply(x)}, {x}, {}, false, false, true);
  ASSERT_FALSE(grads[0].defined());
}
|
|
|
|
TEST(CustomAutogradTest, MaterializeGrads) {
  // With the default settings, an undefined incoming gradient reaches the
  // custom backward as a tensor of zeros.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* /*ctx*/, Variable var) {
      return var;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1));
      return grad_output;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  // UndefinedGrad sends an undefined gradient into MyFunction's backward.
  Variable identity_out = MyFunction::apply(x);
  UndefinedGrad().apply({identity_out})[0].backward();
}
|
|
|
|
TEST(CustomAutogradTest, DontMaterializeGrads) {
  // After set_materialize_grads(false), an undefined incoming gradient is
  // delivered to backward as-is (undefined) instead of as zeros.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      ctx->set_materialize_grads(false);
      return var;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      EXPECT_FALSE(grad_output[0].defined());
      return grad_output;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  Variable identity_out = MyFunction::apply(x);
  UndefinedGrad().apply({identity_out})[0].backward();
}
|
|
|
|
TEST(CustomAutogradTest, NoGradCustomFunction) {
  // Custom Function should respect grad mode
  struct MyOp : public Function<MyOp> {
    static Variable forward(AutogradContext* /*ctx*/, Variable x) {
      return x + 1;
    }

    static variable_list backward(AutogradContext* /*ctx*/, variable_list dy) {
      return dy;
    }
  };

  auto input = torch::ones({5, 5}, torch::requires_grad());
  {
    // Under NoGradGuard the output must not require grad.
    at::NoGradGuard no_grad;
    auto output = MyOp::apply(input);
    ASSERT_FALSE(output.requires_grad());
  }
}
|
|
|
|
TEST(CustomAutogradTest, MarkDirty) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable v) {
      // Write through the raw data pointer, then tell autograd the input was
      // modified in place.
      v.data_ptr<float>()[0] = 2;
      ctx->mark_dirty({v});
      return v;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      return {grad_output[0] * 2.0};
    }
  };

  // Clone here because modifying leafs inplace is not allowed
  auto x = torch::randn({5, 5}, torch::requires_grad()).clone();
  const auto version_before = x._version();
  auto out = MyFunction::apply(x);
  // Applying the dirtying function is expected to advance x's version.
  ASSERT_TRUE(x._version() >= (version_before + 1));
  out.sum().backward();
}
|
|
|
|
TEST(CustomAutogradTest, MarkNonDifferentiable) {
  // A mask marked non-differentiable comes out of apply() without
  // requires_grad and can still be used inside a differentiable graph.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable v) {
      Variable positive = v > 0;
      ctx->mark_non_differentiable({positive});
      return positive;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      return {grad_output[0] * 0.0};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto mask = MyFunction::apply(x);
  ASSERT_FALSE(mask.requires_grad());
  auto filled = x.masked_fill(mask, 0);
  filled.sum().backward();
}
|
|
|
|
TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
  // Two outputs, only the first marked non-differentiable. backward still
  // receives a slot for it, filled with zeros.
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext* ctx, Variable input) {
      Variable a = input + 1;
      Variable b = input + 2;
      ctx->mark_non_differentiable({a});
      return {a, b};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      const Variable& grad_a = grad_output[0];
      const Variable& grad_b = grad_output[1];
      EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
      EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
      return {grad_b};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  variable_list outs = MyFunction::apply(x);
  const Variable& frozen = outs[0];
  const Variable& live = outs[1];

  ASSERT_FALSE(frozen.requires_grad());
  ASSERT_TRUE(live.requires_grad());
  live.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
}
|
|
|
|
TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
  // The only output is non-differentiable and backward returns nothing.
  // Combining that output with x in a differentiable expression must still
  // backprop cleanly.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      Variable copy = input.clone();
      ctx->mark_non_differentiable({copy});
      return copy;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list /*grad_outputs*/) {
      return {};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto r = MyFunction::apply(x * x);
  auto product = r * x;
  product.sum().backward();
}
|
|
|
|
TEST(CustomAutogradTest, ReturnLeafInplace) {
  // In-place update of a leaf that does not require grad, marked dirty:
  // the returned tensor becomes part of the graph through b.
  struct Inplace : public Function<Inplace> {
    static variable_list forward(AutogradContext* ctx, Variable a, Variable b) {
      ctx->mark_dirty({a});
      return {a.add_(b), b + 2};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      const Variable& g0 = grad_output[0];
      const Variable& g1 = grad_output[1];
      return {g0, g0 + g1};
    }
  };

  Variable x = torch::randn({5, 5});
  Variable y = torch::randn({5, 5}, torch::requires_grad());

  variable_list out = Inplace::apply(x, y);
  Variable& q = out[0];
  // q holds x's updated values and now requires grad.
  ASSERT_TRUE(torch::equal(q, x));
  ASSERT_TRUE(q.requires_grad());
  q.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
}
|
|
|
|
TEST(CustomAutogradTest, ReturnDuplicateInplace) {
  // forward modifies its input in place and returns it twice.
  struct DoubleInplace : public Function<DoubleInplace> {
    static variable_list forward(AutogradContext* ctx, Variable x) {
      x.mul_(2);
      ctx->mark_dirty({x});
      return {x, x};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_outputs) {
      const Variable& first = grad_outputs[0];
      const Variable& second = grad_outputs[1];
      return {first * 2 + second * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());

  // Dirtying a leaf that requires grad is rejected.
  ASSERT_THROWS_WITH(
      DoubleInplace::apply(x), "leaf Variable that requires grad");
  // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one
  // output");

  // On a non-leaf clone the duplicated dirty output is accepted.
  variable_list out = DoubleInplace::apply(x.clone());
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}
|
|
|
|
TEST(CustomAutogradTest, ReturnDuplicate) {
  // Returning the same freshly computed tensor twice yields equal outputs.
  struct DoubleDuplicate : public Function<DoubleDuplicate> {
    static variable_list forward(AutogradContext* /*ctx*/, Variable x) {
      Variable doubled = x * 2;
      return {doubled, doubled};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_outputs) {
      const Variable& first = grad_outputs[0];
      const Variable& second = grad_outputs[1];
      return {first * 2 + second * 2};
    }
  };

  auto input = torch::randn({5, 5}, torch::requires_grad());
  variable_list outputs = DoubleDuplicate::apply(input);
  ASSERT_TRUE(torch::equal(outputs[0], outputs[1]));
}
|
|
|
|
TEST(CustomAutogradTest, SaveEmptyForBackward) {
  // Undefined tensors can be saved for backward and come back undefined.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      ctx->save_for_backward({Variable(), input, Variable()});
      return input * input;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const variable_list saved = ctx->get_saved_variables();
      EXPECT_FALSE(saved[0].defined());
      EXPECT_FALSE(saved[2].defined());
      // d(input^2)/d(input) = 2 * input
      const Variable& input = saved[1];
      return {input * 2 * grad_output[0]};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  MyFunction::apply(x).sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
}
|
|
|
|
TEST(CustomAutogradTest, InvalidGradients) {
  // A backward that returns a gradient of the wrong shape must be rejected
  // with an "expected shape" error.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* /*ctx*/, Variable x) {
      return x * 2;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list /*grad_outputs*/) {
      // Shape {10} does not match the {5, 5} input used below.
      return {
          torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
    }
  };

  auto input1 =
      torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
  ASSERT_THROWS_WITH(
      MyFunction::apply(input1).sum().backward(), "expected shape");
  // (An unused `input2` tensor that was constructed and never exercised has
  // been dropped.)
}
|
|
|
|
TEST(CustomAutogradTest, NoGradInput) {
  // Applying a custom function under NoGradGuard records no graph: the output
  // has no grad_fn even though the input requires grad.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(
        AutogradContext*,
        variable_list grad_outputs) {
      return grad_outputs;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = [&x] {
    at::NoGradGuard no_grad;
    return MyFunction::apply(x);
  }();

  ASSERT_TRUE(x.requires_grad());
  ASSERT_FALSE(y.grad_fn());
}
|
|
|
|
TEST(CustomAutogradTest, TooManyGrads) {
  // backward() returns more gradients than forward() has inputs. The extra
  // entries are all undefined, which the engine is expected to tolerate by
  // dropping them (mirrors Python's test_too_many_grads).
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable input) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      grad_output.insert(grad_output.end(), {Variable(), Variable()});
      return grad_output;
    }
  };

  // Actually run the function; merely defining MyFunction tests nothing.
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = MyFunction::apply(x);
  y.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones_like(x));
}
|
|
|
|
TEST(CustomAutogradTest, DepNoGrad) {
  // F1 emits a non-differentiable side output; F2 consumes it (after an
  // intervening op) but ignores it. Backward must still reach x through a.
  struct F1 : public Function<F1> {
    static variable_list forward(AutogradContext* ctx, Variable input) {
      auto noise = torch::randn(input.sizes());
      ctx->mark_non_differentiable({noise});
      return {input, noise};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      return {grad_output[0]};
    }
  };

  struct F2 : public Function<F2> {
    static Variable forward(
        AutogradContext*,
        Variable input,
        Variable /*ignore*/) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      // No gradient for the ignored second argument.
      return {grad_output[0], Variable()};
    }
  };

  auto x = torch::randn(5, torch::requires_grad());
  variable_list out = F1::apply(x);
  Variable& a = out[0];
  Variable& b = out[1];
  b = b + 1; // Separate F1 and F2 by another operation
  ASSERT_TRUE(a.requires_grad());
  ASSERT_FALSE(b.requires_grad());

  auto c = F2::apply(a, b);
  c.backward(torch::ones(c.sizes()), false, false);
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}
|
|
|
|
TEST(CustomAutogradTest, Reentrant) {
  static Variable y_data = torch::randn({2, 2});
  // forward builds a private grad-enabled graph (x * y) and returns a
  // detached copy of its result; backward re-enters the autograd engine on
  // that private graph.
  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      Variable product;
      {
        at::AutoGradMode enable_grad(true);
        auto x = make_variable(input.tensor_data(), true);
        auto y = make_variable(y_data.tensor_data(), true);
        product = x * y;

        ctx->saved_data["x"] = x;
        ctx->saved_data["y"] = y;
        ctx->saved_data["output_var"] = product;
      }
      return product.detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      {
        // Nested backward() call issued from inside a backward.
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["output_var"].toTensor().sum().backward();
      }
      Variable inner_grad = ctx->saved_data["x"].toTensor().grad();
      return {inner_grad * grad_output[0]};
    }
  };

  auto x = torch::randn({2, 2}, torch::requires_grad());
  Reenter::apply(x).sum().backward();
  // d(sum(x * y))/dx = y
  ASSERT_VARIABLE_EQ(x.grad(), y_data);
}
|
|
|
|
// NOTE: If this fails for apparently unrelated reasons in TSAN be aware of
|
|
// the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950
|
|
TEST(CustomAutogradTest, DeepReentrant) {
|
|
struct DeepReenter : public Function<DeepReenter> {
|
|
static Variable forward(AutogradContext* ctx, Variable x) {
|
|
{
|
|
at::AutoGradMode enable_grad(true);
|
|
ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
|
|
}
|
|
return ctx->saved_data["x"].toTensor().detach();
|
|
}
|
|
|
|
static variable_list backward(
|
|
AutogradContext* ctx,
|
|
variable_list grad_output) {
|
|
if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
|
|
return grad_output;
|
|
}
|
|
{
|
|
at::AutoGradMode enable_grad(true);
|
|
apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
|
|
return grad_output;
|
|
}
|
|
}
|
|
};
|
|
|
|
// This should not stack overflow
|
|
auto v =
|
|
torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
|
|
DeepReenter::apply(v).sum().backward();
|
|
}
|
|
|
|
TEST(CustomAutogradTest, ReentrantPriority) {
  // Execution log shared with the local Function structs below: 0 for a
  // MyFunction backward, 1 for each Reenter backward.
  static std::vector<int> order;
  // Clear up front as well: if a previous iteration (e.g. --gtest_repeat)
  // failed an ASSERT it returned before the final clear() and left stale
  // entries behind.
  order.clear();

  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad) {
      order.push_back(0);
      return grad;
    }
  };

  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      order.push_back(1);
      if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
        return grad_output;
      }
      {
        // Re-enter the engine with a value one smaller.
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  auto a = MyFunction::apply(
      torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto b = Reenter::apply(
      torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto v = a * b;
  v.backward();

  // All the reentrant tasks should be prioritized over the MyFunction backward
  // task. Reenter(9) runs backward for saved values 8..0 (9 entries of 1),
  // then MyFunction records its single 0.
  ASSERT_EQ(order.size(), 10u);
  ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
  ASSERT_EQ(order.back(), 0);
  // Clear static variable in case test get executed in a loop
  order.clear();
}
|
|
|
|
TEST(CustomAutogradTest, Hooks) {
  Variable x = torch::ones({5, 5}, torch::requires_grad());
  Variable y = torch::ones({5, 5}) * 4;
  y.set_requires_grad(true);

  // Each counting hook adds its own increment, so `counter` shows which
  // hooks fired during each backward pass.
  int counter = 0;
  std::function<void(int, Variable)> bw_hook(
      [&counter](int inc, Variable /*grad*/) { counter += inc; });
  auto counting_hook = [&bw_hook](int inc) {
    return [&bw_hook, inc](Variable grad) { bw_hook(inc, grad); };
  };

  Variable z = x * x + x * 2 + x * y + y;
  x.register_hook(counting_hook(0));
  auto hook_1 = z.register_hook(counting_hook(1));
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 1);

  auto hook_2 = z.register_hook(counting_hook(2));
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 4); // 1 + (1 + 2)

  z.remove_hook(hook_2);
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 5); // only hook_1 is left

  // A hook that rewrites the gradient by doubling it.
  std::function<Variable(Variable)> bw_hook_modify(
      [](Variable grad) { return grad.mul(2); });

  z.remove_hook(hook_1);
  z.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), true, false);
  // dz/dy = x + 1, doubled once by the hook on z.
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);

  y.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), false, false);
  // ...and doubled again by the hook on y.
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);

  ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}
|
|
|
|
TEST(CustomAutogradTest, HooksInplace) {
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  // The hook registered before the in-place mul_ expects the doubled
  // gradient; the hook registered after it expects plain ones.
  int before_calls = 0;
  auto before_hook = [&before_calls](Variable grad) {
    ++before_calls;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  };

  int after_calls = 0;
  auto after_hook = [&after_calls](Variable grad) {
    ++after_calls;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  };

  a.register_hook(before_hook);
  a.mul_(2);
  a.register_hook(after_hook);

  (a + 1).sum().backward();

  ASSERT_EQ(before_calls, 1);
  ASSERT_EQ(after_calls, 1);
}
|
|
|
|
// Same setup as HooksInplace, with retain_grad() added. Hooks registered
// before the in-place op (on either side of retain_grad()) see the gradient
// of the pre-mutation value. The hook registered after the in-place op sees
// the gradient of the post-mutation value. The retained .grad() holds the
// gradient of the post-mutation value.
TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
  // clone() makes `a` a non-leaf, so the in-place mul_ below is allowed.
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  int hook1_count = 0;
  // Registered before retain_grad() and before mul_: expects 2.
  auto hook1 = ([&hook1_count](Variable grad) {
    hook1_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  });

  int hook2_count = 0;
  // Registered after retain_grad() but before mul_: also expects 2.
  auto hook2 = ([&hook2_count](Variable grad) {
    hook2_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  });

  int hook3_count = 0;
  // Registered after mul_: expects the post-mutation gradient, 1.
  auto hook3 = ([&hook3_count](Variable grad) {
    hook3_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  });

  a.register_hook(hook1);
  a.retain_grad();
  a.register_hook(hook2);

  a.mul_(2);
  a.register_hook(hook3);

  auto out = (a + 1).sum();
  out.backward();

  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
  ASSERT_EQ(hook3_count, 1);

  // retain_grad() survives the in-place op, and .grad() holds the gradient
  // w.r.t. the post-mutation value.
  ASSERT_TRUE(a.retains_grad());
  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}
|
|
|
|
// Same as HooksInplaceWithRetainsGrad, but with two in-place mul_(2) calls.
// Hooks registered before both mutations see 2 * 2 = 4. The hook registered
// after them sees 1, and the retained .grad() is 1.
TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
  // clone() makes `a` a non-leaf, so the in-place mul_ calls are allowed.
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  int hook1_count = 0;
  // Registered before both mul_ calls: expects 4.
  auto hook1 = ([&hook1_count](Variable grad) {
    hook1_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
  });

  int hook2_count = 0;
  // Registered after retain_grad() but before both mul_ calls: expects 4.
  auto hook2 = ([&hook2_count](Variable grad) {
    hook2_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
  });

  int hook3_count = 0;
  // Registered after both mul_ calls: expects 1.
  auto hook3 = ([&hook3_count](Variable grad) {
    hook3_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  });

  a.register_hook(hook1);
  a.retain_grad();
  a.register_hook(hook2);

  a.mul_(2);
  a.mul_(2);
  a.register_hook(hook3);

  auto out = (a + 1).sum();
  out.backward();

  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
  ASSERT_EQ(hook3_count, 1);

  // .grad() holds the gradient w.r.t. the latest version of `a`.
  ASSERT_TRUE(a.retains_grad());
  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}
|
|
|
|
// A custom Function whose backward returns an undefined gradient for an
// input that does not require grad. Hooks on the Function's outputs must
// still receive defined gradients, and backward must complete.
TEST(CustomAutogradTest, HookNone) {
  struct NoneGradientFunction : public Function<NoneGradientFunction> {
    // Identity on both inputs.
    static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
      return {x, y};
    }

    // Passes the incoming gradient through for x and returns an undefined
    // gradient for y.
    static variable_list backward(AutogradContext* ctx, variable_list grad) {
      return {grad[0], Variable()};
    }
  };

  bool was_called = false;

  auto hook = ([&was_called](Variable grad) {
    ASSERT_TRUE(grad.defined());
    was_called = true;
  });

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5});

  auto out = NoneGradientFunction::apply(x, y);
  // Hook the Function's outputs, not slices of `x`, so that the backward pass
  // actually runs through NoneGradientFunction::backward.
  Variable rx = out[0], ry = out[1];

  rx.register_hook(hook);
  ry.register_hook(hook);
  (rx + ry).sum().backward();
  ASSERT_TRUE(was_called);
}
|
|
|
|
// backward(..., inputs={x}) accumulates gradients only into the listed
// tensors. y also requires grad, but must be left without a .grad().
TEST(CustomAutogradTest, BackwardWithInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  // dz/dx for z = x^2 + x*y + y^2.
  Variable x_grad_expected = 2 * x + y;

  z.backward(torch::ones({5, 5}), false, false, {x});

  ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
  ASSERT_FALSE(y.grad().defined());
}
|
|
|
|
// Passing an explicitly empty `inputs` list to backward() is rejected.
TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  ASSERT_THROWS_WITH(
      z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}),
      "cannot be empty");
}
|
|
|
|
// backward() restricted to {leaf x, non-leaf z}. Gradients are accumulated
// into both listed tensors, while y (which also requires grad) is left
// without a .grad().
TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
  Variable x_leaf = torch::randn({5, 5}, torch::requires_grad());
  Variable y_leaf = torch::randn({5, 5}, torch::requires_grad());
  Variable z_mid = x_leaf * x_leaf;
  Variable w_out = y_leaf * z_mid + x_leaf * y_leaf + y_leaf * y_leaf;

  w_out.backward(
      torch::ones({5, 5}), false, false, std::vector<Variable>{x_leaf, z_mid});

  // dw/dx = y * 2x + y (through z and directly); dw/dz = y.
  ASSERT_VARIABLE_EQ(x_leaf.grad(), 2 * x_leaf * y_leaf + y_leaf);
  ASSERT_VARIABLE_EQ(z_mid.grad(), y_leaf);
  ASSERT_FALSE(y_leaf.grad().defined());
}
|
|
|
|
// Both Tensor::backward and torch::autograd::backward must warn when
// create_graph=true is used.
TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
  c10::WarningUtils::WarnAlways guard(true);

  torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true);
  auto z = x * x;

  // Runs `run_backward` with warnings captured and reports whether the
  // create_graph warning was emitted.
  auto warned_about_create_graph = [](const auto& run_backward) {
    WarningCapture warnings;
    run_backward();
    return warnings.str().find("Using backward() with create_graph=True") !=
        std::string::npos;
  };

  // Member form: backward(gradient, retain_graph=nullopt, create_graph=true).
  ASSERT_TRUE(warned_about_create_graph(
      [&] { z.backward(torch::ones({5, 5}), std::nullopt, true); }));

  // Free-function form with the same arguments.
  ASSERT_TRUE(warned_about_create_graph([&] {
    torch::autograd::backward({z}, {torch::ones({5, 5})}, std::nullopt, true);
  }));
}
|
|
|
|
/**
|
|
* Tests for AutogradNotImplementedFallback
|
|
* - Check that the NotImplemented node is created when any input requires
*   grad, and that no such node is created when no input requires grad
|
|
* - check_inplace logic
|
|
* - view ops
|
|
* - TODO: Tests for debug-only checks? Don't need for now because CI doesn't
|
|
* test non-NDEBUG builds.
|
|
* - tensorlist input and output
|
|
* - multiple outputs / non-tensor output
|
|
* - rebase_history vs set_history
|
|
*/
|
|
namespace {
|
|
|
|
torch::Tensor inplace_op(
|
|
const torch::Tensor& self,
|
|
const torch::Tensor& other) {
|
|
return self.add_(other);
|
|
}
|
|
|
|
std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(
|
|
const torch::Tensor& self,
|
|
const torch::Tensor& other) {
|
|
other.add_(self);
|
|
self.add_(other);
|
|
return std::tuple<torch::Tensor, torch::Tensor>(self, other);
|
|
}
|
|
|
|
std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(
|
|
const torch::Tensor& self,
|
|
const torch::Tensor& other) {
|
|
// This is not allowed. We test below that this calling into the boxed kernel
|
|
// will raise an error
|
|
return std::tuple<torch::Tensor, torch::Tensor>(self, other);
|
|
}
|
|
|
|
std::tuple<torch::Tensor, torch::Tensor> non_first_view_op(
|
|
const torch::Tensor& self,
|
|
const torch::Tensor& other) {
|
|
// This is not allowed. We test below that this calling into the boxed kernel
|
|
// will raise an error
|
|
return std::tuple<torch::Tensor, torch::Tensor>(self.clone(), other);
|
|
}
|
|
|
|
int64_t ret_single_non_tensor(
|
|
const torch::Tensor& self,
|
|
const torch::Tensor& other) {
|
|
return 12;
|
|
}
|
|
|
|
torch::Tensor opt_op(
|
|
const torch::Tensor& self,
|
|
const std::optional<at::Tensor>& other) {
|
|
if (other.has_value()) {
|
|
return self + other.value();
|
|
} else {
|
|
return self.clone();
|
|
}
|
|
}
|
|
|
|
torch::Tensor my_custom_op(
|
|
const torch::Tensor& self,
|
|
const torch::Tensor& other) {
|
|
return self + other;
|
|
}
|
|
|
|
std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
|
|
const torch::Tensor& self,
|
|
const torch::Tensor& other) {
|
|
auto a = self - other;
|
|
auto b = self + other;
|
|
return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12);
|
|
}
|
|
|
|
torch::Tensor view_op(const torch::Tensor& self) {
|
|
return self.alias();
|
|
}
|
|
|
|
torch::Tensor view_op_with_extra_arg(
|
|
const torch::Tensor& self,
|
|
const torch::Tensor& other) {
|
|
return self.alias();
|
|
}
|
|
|
|
std::vector<torch::Tensor> ret_tensor_vector_view(
|
|
const torch::Tensor& self,
|
|
const torch::Tensor& other) {
|
|
return {self.alias(), self.alias()};
|
|
}
|
|
|
|
std::vector<at::Tensor> ret_tensor_vector(
|
|
const torch::Tensor& self,
|
|
const torch::Tensor& other) {
|
|
std::vector<at::Tensor> out;
|
|
out.push_back(self + other);
|
|
out.push_back(self - other);
|
|
return out;
|
|
}
|
|
|
|
torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
|
|
const auto& res = self.clone();
|
|
for (const auto& t : other) {
|
|
res.add_(t);
|
|
}
|
|
return res;
|
|
}
|
|
|
|
#define REGISTER_TEST_OP(name, schema, fn) \
|
|
auto m = MAKE_TORCH_LIBRARY(_test); \
|
|
m.def(schema); \
|
|
auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd); \
|
|
auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); \
|
|
auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView); \
|
|
m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn)); \
|
|
m_autograd.impl( \
|
|
name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); \
|
|
m_inplaceorview.impl( \
|
|
name, \
|
|
c10::DispatchKey::ADInplaceOrView, \
|
|
autogradNotImplementedInplaceOrViewFallback());
|
|
|
|
template <typename F>
|
|
void assertBasicChecks(F op) {
|
|
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
|
|
auto b = torch::tensor({1.}, {torch::kFloat32});
|
|
auto c = torch::tensor({1.}, {torch::kFloat32});
|
|
|
|
// If any inputs require grad,
|
|
auto out1 = op(a, b);
|
|
ASSERT_THROWS_WITH(out1.backward(), "is not implemented");
|
|
|
|
// # Should not have grad_fn if none require grad
|
|
auto out2 = op(b, c);
|
|
ASSERT_THROWS_WITH(
|
|
out2.backward(),
|
|
"element 0 of tensors does not require grad and does not have a grad_fn");
|
|
|
|
// TODO: Forward AD Tests?
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// A single non-tensor return passes through the autograd fallback unchanged.
TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
  REGISTER_TEST_OP(
      "ret_single_non_tensor",
      "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int",
      ret_single_non_tensor);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_single_non_tensor", "");
  // Goes through the dispatcher so the registered fallbacks are exercised.
  auto call_op = [&](const torch::Tensor& self_arg,
                     const torch::Tensor& other_arg) {
    return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(
        handle, self_arg, other_arg);
  };

  auto needs_grad =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto plain = torch::tensor({1.}, {torch::kFloat32});

  ASSERT_EQ(
      call_op(needs_grad, plain), ret_single_non_tensor(needs_grad, plain));
}
|
|
|
|
// In-place op through the autograd not-implemented fallback. Covers leaf
// checks, correctness of the result, and in-place updates on views.
TEST(TestAutogradNotImplementedFallback, InplaceOp) {
  REGISTER_TEST_OP(
      "inplace_op",
      "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)",
      inplace_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
  // Calls the registered op through the dispatcher, not the raw kernel.
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Mutating a leaf that requires grad is rejected...
  ASSERT_THROWS_WITH(
      op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  // ...but mutating a tensor that does not require grad is fine, even when
  // the other argument requires grad.
  op(b, a);

  // Non-leaf copies may be mutated. `op` returns its mutated first argument,
  // so compute the reference result on an independent clone *before*
  // dispatching. Comparing against inplace_op(a, b) afterwards would mutate
  // and return the same tensor `c` aliases, making the check tautological.
  a = a.clone();
  b = b.clone();
  auto expected = inplace_op(a.clone(), b);
  auto c = op(a, b);
  ASSERT_TRUE(torch::allclose(c, expected));

  // Test in-place on view
  auto base =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto view = base.view(-1);
  auto t = torch::tensor({1.}, {torch::kFloat32});

  torch::Tensor v_nograd;
  {
    // Under no_grad, both creating the view and mutating it succeed.
    c10::NoGradGuard guard;
    v_nograd = base.view(-1);
    op(v_nograd, t);
  }

  // With grad mode back on, mutating a view created under no_grad is rejected.
  ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
  // The op returns its (aliased) first argument...
  ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
  // ...and an in-place op on a view gives it an AsStridedBackward grad_fn.
  ASSERT_THAT(
      op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
}
|
|
|
|
// An op that mutates both of its arguments. Leaf checks apply to each
// mutated argument, and both arguments' version counters change.
TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
  REGISTER_TEST_OP(
      "two_arg_inplace_op",
      "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))",
      two_arg_inplace_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_arg_inplace_op", "");
  // Calls the registered op through the dispatcher.
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Both are modified in-place!
  // A leaf that requires grad is therefore rejected in either position.
  ASSERT_THROWS_WITH(
      op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  ASSERT_THROWS_WITH(
      op(b, a),
      "a leaf Variable that requires grad is being used in an in-place operation");

  // Non-leaf clones may be mutated.
  auto c =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto d =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();

  auto saved_version_c = c._version();
  auto saved_version_d = d._version();
  op(c, d);
  // Both tensors were mutated, so both version counters must have changed.
  ASSERT_NE(c._version(), saved_version_c);
  ASSERT_NE(d._version(), saved_version_d);
}
|
|
|
|
// An op with an optional tensor argument gives the same result through the
// dispatcher as the raw kernel, with and without the optional argument.
TEST(TestAutogradNotImplementedFallback, OptOp) {
  REGISTER_TEST_OP(
      "opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
  auto handle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
  // `maybe_other` maps to the schema's optional `Tensor? other` argument.
  auto call_op = [&](const torch::Tensor& self_arg,
                     const std::optional<torch::Tensor>& maybe_other) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const std::optional<torch::Tensor>&>(handle, self_arg, maybe_other);
  };

  auto lhs = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto rhs = torch::tensor({1.}, {torch::kFloat32});

  ASSERT_TRUE(torch::allclose(call_op(lhs, rhs), opt_op(lhs, rhs)));
  ASSERT_TRUE(torch::allclose(call_op(lhs, {}), opt_op(lhs, {})));
}
|
|
|
|
// A plain out-of-place op only needs the shared grad_fn checks.
TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
  REGISTER_TEST_OP(
      "my_custom_op",
      "_test::my_custom_op(Tensor self, Tensor other) -> Tensor",
      my_custom_op);
  auto handle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");

  assertBasicChecks(
      [&](const torch::Tensor& self_arg, const torch::Tensor& other_arg) {
        return callOpUnboxed<
            torch::Tensor,
            const torch::Tensor&,
            const torch::Tensor&>(handle, self_arg, other_arg);
      });
}
|
|
|
|
// An op returning (Tensor, Tensor, int). The shared checks are run on the
// first tensor output.
TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
  REGISTER_TEST_OP(
      "ret_tuple_non_tensor",
      "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)",
      ret_tuple_non_tensor);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tuple_non_tensor", "");

  assertBasicChecks(
      [&](const torch::Tensor& self_arg, const torch::Tensor& other_arg) {
        // Keep only the first element; the second tensor and the int are
        // dropped.
        return std::get<0>(callOpUnboxed<
                           std::tuple<torch::Tensor, torch::Tensor, int64_t>,
                           const torch::Tensor&,
                           const torch::Tensor&>(handle, self_arg, other_arg));
      });
}
|
|
|
|
// A view op (output aliases `self`) through the autograd not-implemented
// fallback. Covers view tracking and in-place updates on the view.
TEST(TestAutogradNotImplementedFallback, ViewOp) {
  REGISTER_TEST_OP(
      "view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
  auto op = [&](const torch::Tensor& _1) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&>(opHandle, _1);
  };
  // The output is recorded as a view of the input when the input does not
  // require grad...
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto v = op(b);
  ASSERT_TRUE(v.is_view());
  ASSERT_EQ(v._base().unsafeGetTensorImpl(), b.unsafeGetTensorImpl());

  // ...and also when the input requires grad.
  auto b1 =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto v1 = op(b1);
  ASSERT_TRUE(v1.is_view());
  ASSERT_EQ(v1._base().unsafeGetTensorImpl(), b1.unsafeGetTensorImpl());

  // Test inplace on view
  auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);

  // raise on rebase_history when it refreshes grad_fn
  ASSERT_THROWS_WITH(
      v1.add_(t), "which does not have a derivative implemented is forbidden");
  // base should not be aware of the views, so this is still okay
  b1.add_(t);
  // Reading the view's grad_fn after its base was modified in place raises
  // the same error (presumably because the view's grad_fn is refreshed here).
  ASSERT_THROWS_WITH(
      v1.grad_fn(),
      "which does not have a derivative implemented is forbidden");
}
|
|
|
|
// A view op with an extra non-aliased argument. It passes the shared checks,
// and its output is a view of the first argument only.
TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
  REGISTER_TEST_OP(
      "view_op_with_extra_arg",
      "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
      view_op_with_extra_arg);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::view_op_with_extra_arg", "");
  auto call_op = [&](const torch::Tensor& self_arg,
                     const torch::Tensor& other_arg) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self_arg, other_arg);
  };
  assertBasicChecks(call_op);

  auto source = torch::tensor({1.}, {torch::kFloat32});
  auto extra = torch::tensor({2.}, {torch::kFloat32});
  auto aliased = call_op(source, extra);
  // Only `self` carries the alias annotation, so it is the view's base.
  ASSERT_TRUE(aliased.is_view());
  ASSERT_EQ(
      aliased._base().unsafeGetTensorImpl(), source.unsafeGetTensorImpl());
}
|
|
|
|
// An op returning a list of views of `self`. Both returned tensors must be
// views whose base is `self`.
TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
  REGISTER_TEST_OP(
      "ret_tensor_vector_view",
      "_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)",
      ret_tensor_vector_view);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tensor_vector_view", "");
  auto call_op = [&](const torch::Tensor& self_arg,
                     const torch::Tensor& other_arg) {
    return callOpUnboxed<
        std::vector<at::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self_arg, other_arg);
  };

  auto self_t = torch::tensor({1.}, {torch::kFloat32});
  auto other_t = torch::tensor({1.}, {torch::kFloat32});
  auto views = call_op(self_t, other_t);
  for (size_t idx = 0; idx < 2; ++idx) {
    ASSERT_TRUE(views[idx].is_view());
    ASSERT_EQ(
        views[idx]._base().unsafeGetTensorImpl(), self_t.unsafeGetTensorImpl());
  }
}
|
|
|
|
// A schema in which two outputs alias two different inputs is not supported
// by the fallback, and calling the op must raise.
TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
  REGISTER_TEST_OP(
      "two_pairs_of_view_op",
      "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))",
      two_pairs_of_view_op);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_pairs_of_view_op", "");
  auto call_op = [&](const torch::Tensor& self_arg,
                     const torch::Tensor& other_arg) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self_arg, other_arg);
  };

  auto needs_grad =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto plain = torch::tensor({1.}, {torch::kFloat32});
  ASSERT_THROWS_WITH(
      call_op(needs_grad, plain),
      "Expected only a single output in the operator schema to have a non-write alias annotation");
}
|
|
|
|
// A schema whose view relationship is not between the first input and the
// first output is not supported by the fallback, and calling the op must
// raise.
TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) {
  REGISTER_TEST_OP(
      "non_first_view_op",
      "_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))",
      non_first_view_op);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::non_first_view_op", "");
  auto call_op = [&](const torch::Tensor& self_arg,
                     const torch::Tensor& other_arg) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self_arg, other_arg);
  };

  auto needs_grad =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto plain = torch::tensor({1.}, {torch::kFloat32});
  ASSERT_THROWS_WITH(
      call_op(needs_grad, plain),
      "can only create view relationships between the first");
}
|
|
|
|
// An op returning Tensor[]. The shared checks are run on the first element.
TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
  REGISTER_TEST_OP(
      "ret_tensor_vector",
      "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]",
      ret_tensor_vector);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tensor_vector", "");

  assertBasicChecks(
      [&](const torch::Tensor& self_arg, const torch::Tensor& other_arg) {
        return callOpUnboxed<
                   std::vector<at::Tensor>,
                   const torch::Tensor&,
                   const torch::Tensor&>(handle, self_arg, other_arg)
            .front();
      });
}
|
|
|
|
// An op taking a Tensor[] argument. Differentiating w.r.t. a list element
// that does not require grad fails up front. Differentiating w.r.t. one that
// does require grad reaches the not-implemented node. The forward result
// matches the raw kernel.
TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
  REGISTER_TEST_OP(
      "tensorlist_op",
      "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor",
      tensorlist_op);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::tensorlist_op", "");
  auto call_op = [&](const torch::Tensor& self_arg, at::TensorList others) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(
        handle, self_arg, others);
  };

  auto self_t = torch::tensor({1.}, {torch::kFloat32});
  auto frozen = torch::tensor({1.}, {torch::kFloat32});
  auto trainable =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  std::vector<torch::Tensor> others = {frozen, trainable};
  auto result = call_op(self_t, others);

  ASSERT_THROWS_WITH(
      torch::autograd::grad({result}, {others[0]}),
      "element 0 of the input tensors does not require grad");
  ASSERT_THROWS_WITH(
      torch::autograd::grad({result}, {others[1]}), "is not implemented");

  ASSERT_TRUE(
      at::allclose(call_op(self_t, others), tensorlist_op(self_t, others)));
}
|
|
|
|
// Identity error formatter passed to validate_outputs: returns the message
// unchanged.
static std::string test_format_error(const std::string& s) {
  std::string message(s);
  return message;
}
|
|
|
|
// The input is 0-dim but the incoming gradient is 2x3, so validate_outputs
// is expected to sum the gradient down to the input's shape.
TEST(TestAutogradUtils, ValidateOutputsReduce) {
  auto scalar_input = torch::ones({}, {torch::kFloat32});
  auto wide_grad = torch::ones({2, 3}, {torch::kFloat32});

  std::vector<std::optional<InputMetadata>> metadata;
  metadata.emplace_back(InputMetadata(scalar_input));
  std::vector<torch::Tensor> grads{wide_grad};

  torch::autograd::validate_outputs(metadata, grads, test_format_error);
  ASSERT_TRUE(at::allclose(grads[0], wide_grad.sum()));
}
|
|
|
|
// When the gradient already matches the input's metadata (same 2x3 shape and
// float32 dtype), validate_outputs must leave it as is.
TEST(TestAutogradUtils, ValidateOutputsBasic) {
  auto input = torch::zeros({2, 3}, {torch::kFloat32});
  auto grad = torch::ones({2, 3}, {torch::kFloat32});

  std::vector<std::optional<InputMetadata>> input_metadata;
  input_metadata.emplace_back(InputMetadata(input));
  std::vector<torch::Tensor> grads;
  grads.emplace_back(grad);

  torch::autograd::validate_outputs(input_metadata, grads, test_format_error);
  // Inspect the entry in `grads`, which validate_outputs may replace (as
  // ValidateOutputsReduce shows). The local `grad` still refers to the
  // original tensor, so checking it could never fail.
  ASSERT_TRUE(at::allclose(grads[0], torch::ones({2, 3})));
}
|
|
|
|
// TODO add these tests if needed
|
|
// test_once_differentiable
|
|
// test_sparse_backward
|
|
// test_save_output_nr
|
|
// test_free_deep_graph_pyfunction
|
|
// test_naughty_anomaly_access
|
|
// test_naughty_autograd_function_stashing_ctx
|
|
// test_custom_autograd_repeated_grad_grad
|
|
// test_return_leaf
|
|
// test_anomaly_detect_nan
|
|
// test_no_grad_copy
|