#include <gtest/gtest.h>

#include <ATen/core/boxing/impl/test_helpers.h>
#include <ATen/core/op_registration/op_registration.h>

#include <torch/torch.h>

#include <torch/csrc/autograd/functions/basic_ops.h>

#include <test/cpp/api/support.h>

using namespace torch::autograd;
using namespace torch::test;

#define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b)))
#define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b)))

std::string graph_desc(std::shared_ptr<Node> node) {
  if (!node) {
    return "None";
  }
  auto result = node->name() + "(";
  auto next_edges = node->next_edges();
  for (auto& edge : next_edges) {
    result += graph_desc(edge.function);
  }
  return result + ")";
}

// d(simple_fn)/dx = 1 + y and d(simple_fn)/dy = 2 + x, which is what the
// gradient checks below assert.
Variable simple_fn(const Variable& x, const Variable& y) {
  return x + 2 * y + x * y;
}

TEST(AutogradAPITests, BackwardSimpleTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res.sum()}, {});

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, BackwardTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res}, {torch::ones({2, 2})}, {}, true);
  backward({res}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
  ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
}

TEST(AutogradAPITests, GradSimpleTest) {
  // basic grad
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, GradTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  res.backward(torch::ones({2, 2}), false, true);

  Variable x_grad = y + torch::ones({2, 2});
  Variable y_grad = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);

  Variable grad_sum = 2 * x.grad() + y.grad();
  auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);

  ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);
}

TEST(AutogradAPITests, GradNonLeafTest) {
  Variable x_init = torch::randn({2, 2}, torch::requires_grad());
  Variable x = x_init;
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable grad_output = torch::ones({2, 2});

  for (int i = 0; i < 5; ++i) {
    auto res = simple_fn(x, y);
    auto input_grads = grad({res}, {x}, {grad_output}, {}, true);

    Variable grad_x_expected = y + torch::ones({2, 2});
    ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
    ASSERT_FALSE(x.grad().defined());
    ASSERT_FALSE(y.grad().defined());
    x = x + 0.05 * input_grads[0];
  }

  float val_init = simple_fn(x_init, y).sum().item().toFloat();
  float val_final = simple_fn(x, y).sum().item().toFloat();
  ASSERT_TRUE(val_final > val_init);

  x.backward(grad_output, false, true);
  ASSERT_TRUE(x_init.grad().defined());
  ASSERT_TRUE(y.grad().defined());
}
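// The positional arguments passed to grad() and backward() throughout this
// file follow (approximately) these signatures from torch::autograd:
//   grad(outputs, inputs, grad_outputs = {}, retain_graph = c10::nullopt,
//        create_graph = false, allow_unused = false)
//   backward(tensors, grad_tensors = {}, retain_graph = c10::nullopt,
//            create_graph = false, inputs = {})
// Exact defaults may differ between versions; see
// torch/csrc/autograd/autograd.h for the authoritative declarations.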
TEST(AutogradAPITests, GradUnreachableTest) {
  Variable x = torch::ones({1}, torch::requires_grad());
  Variable y = torch::ones({1}, torch::requires_grad());

  Variable z = x * 2;
  Variable w = y * 2;

  auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // This is slightly different than the case above, because z doesn't even
  // have a grad accumulator allocated.
  z = torch::ones({1}, torch::requires_grad());
  grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);

  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // allow_unused=False, but grads contains None inside, should throw
  ASSERT_THROWS_WITH(
      grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
}

TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
  // Test that certain nodes are not erroneously executed when an input
  // is unreachable. See #39784
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      return var;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      ADD_FAILURE() << "This node should not be executed!";
      return grad_output;
    }
  };

  auto x = torch::randn(1, torch::requires_grad());
  auto x1 = torch::randn(1);
  auto x2 = MyFunction::apply(x + x1);

  auto y = torch::randn(1, torch::requires_grad());
  auto grad_res = torch::autograd::grad({x2}, {y}, {}, {}, false, true);
  ASSERT_FALSE(grad_res[0].defined());
}

TEST(AutogradAPITests, EmptyInput) {
  Variable x = torch::ones({1}, torch::requires_grad());
  ASSERT_THROWS_WITH(
      grad({x * 2}, /*inputs=*/{}, {x}), "grad requires non-empty inputs.");
}

TEST(AutogradAPITests, RetainGrad) {
  auto input = torch::rand({1, 3}, torch::requires_grad());
  auto h1 = input * 3;
  auto out = (h1 * h1).sum();

  {
    // Warning when grad is accessed for non-leaf tensor
    WarningCapture warnings;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_TRUE(warnings.str().find("is not a leaf") != std::string::npos);
  }
  // It should be possible to call retain_grad() multiple times
  h1.retain_grad();
  h1.retain_grad();
  {
    // If retain_grad is true for a non-leaf tensor,
    // there should not be any warning when grad is accessed
    WarningCapture warnings;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_FALSE(warnings.str().find("is not a leaf") != std::string::npos);
  }

  // Gradient should be accumulated
  // NOLINTNEXTLINE(bugprone-argument-comment)
  out.backward({}, /*keep_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
  // NOLINTNEXTLINE(bugprone-argument-comment)
  out.backward({}, /*keep_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());

  {
    torch::NoGradGuard no_grad;
    input.grad().zero_();
  }
  // It should be a no-op for leaves
  input.retain_grad();
  input.retain_grad();
  out.backward();
  ASSERT_VARIABLE_EQ(input * 18, input.grad());
}

TEST(AutogradAPITests, AnomalyMode) {
  // Needs to have backtrace as warning and then throw an error
  torch::autograd::DetectAnomalyGuard detect_anomaly;
  {
    WarningCapture warnings;
    auto x = torch::tensor({5.0}, torch::requires_grad());
    auto y = x * x;
    auto z = y * y;
    y += 1;
    ASSERT_THROWS_WITH(z.backward(), "inplace");
    ASSERT_TRUE(
        warnings.str().find("Traceback of forward") != std::string::npos);
  }
  {
    WarningCapture warnings;
    // Double backward
    auto x = torch::tensor({0.0}, torch::requires_grad());
    auto y = x.pow(1.5);
    auto gr =
        // NOLINTNEXTLINE(bugprone-argument-comment)
        grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
    ASSERT_THROWS_WITH(
        grad({gr[0]}, {x}, {torch::tensor({0.0})}), "returned nan");
    auto msgs = warnings.messages();
    ASSERT_EQ(msgs.size(), 2);
    ASSERT_TRUE(
        msgs[0].find("Traceback of forward call that caused the error") !=
        std::string::npos);
    ASSERT_TRUE(
        msgs[1].find(
            "Traceback of forward call that induced the previous calculation") !=
        std::string::npos);
  }
}
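// The CustomAutogradTest cases below exercise the C++ custom autograd
// function API: a struct derives from torch::autograd::Function<Self> via
// CRTP, defines a static forward() that takes an AutogradContext* plus
// arbitrary arguments, and a static backward() that receives one incoming
// gradient per forward output and must return one entry per forward input
// (an undefined Variable() for non-differentiable inputs such as the int
// argument below). Tensors needed in backward() are stashed with
// ctx->save_for_backward(); other values go through ctx->saved_data.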
TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        int mul,
        Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      variable_list output = {
          grad_output[0] + grad_output[0] * var2,
          Variable(),
          grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto res = MyFunction::apply(x, 2, y);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}

TEST(CustomAutogradTest, FunctionReturnsInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var1) {
      return var1;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {grad_output[0] * 2};
    }
  };

  Variable x(torch::ones(1, torch::requires_grad()));
  MyFunction::apply(x).backward(torch::ones(1), true, true);
  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
}

TEST(CustomAutogradTest, FunctionReturnsUndefined) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      return var * 2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      at::Tensor undefined_tensor;
      return {undefined_tensor};
    }
  };

  auto x = torch::ones(1, torch::requires_grad());

  MyFunction::apply(x).backward();
  ASSERT_FALSE(x.grad().defined());

  MyFunction::apply(x.pow(2)).backward();
  ASSERT_FALSE(x.grad().defined());

  MyFunction::apply(x).sum().backward();
  ASSERT_FALSE(x.grad().defined());

  ASSERT_FALSE(torch::autograd::grad(
      {MyFunction::apply(x)}, {x}, {}, false, false, true)[0].defined());
}

TEST(CustomAutogradTest, MaterializeGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      return var;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1));
      return grad_output;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
}

TEST(CustomAutogradTest, DontMaterializeGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      ctx->set_materialize_grads(false);
      return var;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      EXPECT_FALSE(grad_output[0].defined());
      return grad_output;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
}

TEST(CustomAutogradTest, NoGradCustomFunction) {
  // Custom Function should respect grad mode
  struct MyOp : public Function<MyOp> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      return x + 1;
    }

    static variable_list backward(AutogradContext* ctx, variable_list dy) {
      return dy;
    }
  };

  auto x = torch::ones({5, 5}, torch::requires_grad());
  {
    at::NoGradGuard no_grad;
    auto y = MyOp::apply(x);
    ASSERT_FALSE(y.requires_grad());
  }
}
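// When a custom forward() mutates one of its inputs in place, it has to
// report that tensor through ctx->mark_dirty() so that autograd bumps the
// tensor's version counter; the test below checks this via the internal
// _version() accessor.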
TEST(CustomAutogradTest, MarkDirty) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable v) {
      // Change the value inplace
      auto v_data = v.data_ptr<float>();
      v_data[0] = 2;
      ctx->mark_dirty({v});
      return v;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {(grad_output[0] * 2.0)};
    }
  };

  // Clone here because modifying leafs inplace is not allowed
  auto x = torch::randn({5, 5}, torch::requires_grad()).clone();
  auto version_before = x._version();
  auto out = MyFunction::apply(x);
  auto version_after = x._version();
  ASSERT_TRUE(version_after >= (version_before + 1));
  out.sum().backward();
}

TEST(CustomAutogradTest, MarkNonDifferentiable) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable v) {
      Variable output = v > 0;
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {(grad_output[0] * 0.0)};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto mask = MyFunction::apply(x);
  ASSERT_FALSE(mask.requires_grad());
  auto y = x.masked_fill(mask, 0);
  y.sum().backward();
}

TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext* ctx, Variable input) {
      Variable a = input + 1;
      Variable b = input + 2;
      ctx->mark_non_differentiable({a});
      return {a, b};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
      EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
      EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
      return {grad_b};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto out = MyFunction::apply(x);

  ASSERT_FALSE(out[0].requires_grad());
  ASSERT_TRUE(out[1].requires_grad());
  out[1].sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      auto output = input.clone();
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_outputs) {
      return {};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto r = MyFunction::apply(x * x);
  (r * x).sum().backward();
}

TEST(CustomAutogradTest, ReturnLeafInplace) {
  struct Inplace : public Function<Inplace> {
    static variable_list forward(AutogradContext* ctx, Variable a, Variable b) {
      ctx->mark_dirty({a});
      return {a.add_(b), b + 2};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {grad_output[0], grad_output[0] + grad_output[1]};
    }
  };

  Variable x = torch::randn({5, 5});
  Variable y = torch::randn({5, 5}, torch::requires_grad());

  auto out = Inplace::apply(x, y);
  auto& q = out[0];
  ASSERT_TRUE(torch::equal(q, x));
  ASSERT_TRUE(q.requires_grad());
  q.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, ReturnDuplicateInplace) {
  struct DoubleInplace : public Function<DoubleInplace> {
    static variable_list forward(AutogradContext* ctx, Variable x) {
      x.mul_(2);
      ctx->mark_dirty({x});
      return {x, x};
    }

    static variable_list backward(
        AutogradContext* ctsx,
        variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());

  ASSERT_THROWS_WITH(
      DoubleInplace::apply(x), "leaf Variable that requires grad");
  // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one
  // output");

  auto out = DoubleInplace::apply(x.clone());
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}
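// A custom forward() may also return the same (non-inplace) tensor more than
// once. The test below only asserts that both returned entries compare equal;
// its backward() is written to combine the incoming gradients of both
// aliases.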
TEST(CustomAutogradTest, ReturnDuplicate) {
  struct DoubleDuplicate : public Function<DoubleDuplicate> {
    static variable_list forward(AutogradContext* ctx, Variable x) {
      auto output = x * 2;
      return {output, output};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto out = DoubleDuplicate::apply(x);
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}

TEST(CustomAutogradTest, SaveEmptyForBackward) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      ctx->save_for_backward({Variable(), input, Variable()});
      return input * input;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      auto saved = ctx->get_saved_variables();
      EXPECT_FALSE(saved[0].defined());
      EXPECT_FALSE(saved[2].defined());
      return {saved[1] * 2 * grad_output[0]};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  auto y = MyFunction::apply(x);
  y.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
}

TEST(CustomAutogradTest, InvalidGradients) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      return x * 2;
    }

    static variable_list backward(
        AutogradContext* ctsx,
        variable_list grad_outputs) {
      return {
          torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
    }
  };

  auto input1 =
      torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
  ASSERT_THROWS_WITH(
      MyFunction::apply(input1).sum().backward(), "expected shape");
  auto input2 =
      torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
}

TEST(CustomAutogradTest, NoGradInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(
        AutogradContext*,
        variable_list grad_outputs) {
      return grad_outputs;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y;
  {
    at::NoGradGuard no_grad;
    y = MyFunction::apply(x);
  }

  ASSERT_TRUE(x.requires_grad());
  ASSERT_FALSE(y.grad_fn());
}

TEST(CustomAutogradTest, TooManyGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable input) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      grad_output.insert(grad_output.end(), {Variable(), Variable()});
      return grad_output;
    }
  };
}

TEST(CustomAutogradTest, DepNoGrad) {
  struct F1 : public Function<F1> {
    static variable_list forward(AutogradContext* ctx, Variable input) {
      auto out = torch::randn(input.sizes());
      ctx->mark_non_differentiable({out});
      return {input, out};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {grad_output[0]};
    }
  };

  struct F2 : public Function<F2> {
    static Variable forward(AutogradContext*, Variable input, Variable ignore) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      return {grad_output[0], Variable()};
    }
  };

  auto x = torch::randn(5, torch::requires_grad());
  auto out = F1::apply(x);
  Variable &a = out[0], &b = out[1];
  b = b + 1; // Separate F1 and F2 by another operation
  ASSERT_TRUE(a.requires_grad());
  ASSERT_FALSE(b.requires_grad());

  auto c = F2::apply(a, b);
  c.backward(torch::ones(c.sizes()), false, false);
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}
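// The next tests call backward() from inside a custom backward() node
// ("reentrant" backward). The engine is expected to run the nested call to
// completion and, for deeply nested chains, to hand work off so that
// DeepReentrant below does not overflow the stack.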
TEST(CustomAutogradTest, Reentrant) {
  static Variable y_data = torch::randn({2, 2});
  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      Variable output;
      {
        at::AutoGradMode enable_grad(true);
        auto x = make_variable(input.tensor_data(), true);
        auto y = make_variable(y_data.tensor_data(), true);
        output = x * y;

        ctx->saved_data["x"] = x;
        ctx->saved_data["y"] = y;
        ctx->saved_data["output_var"] = output;
      }
      return output.detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      {
        at::AutoGradMode enable_grad(true);
        auto out = ctx->saved_data["output_var"].toTensor();
        out.sum().backward();
      }
      return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
    }
  };

  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto out = Reenter::apply(x);
  out.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), y_data);
}

// NOTE: If this fails for apparently unrelated reasons in TSAN be aware of
// the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950
TEST(CustomAutogradTest, DeepReentrant) {
  struct DeepReenter : public Function<DeepReenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  // This should not stack overflow
  auto v =
      torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
  DeepReenter::apply(v).sum().backward();
}

TEST(CustomAutogradTest, ReentrantPriority) {
  static std::vector<int> order;

  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad) {
      order.push_back(0);
      return grad;
    }
  };

  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      order.push_back(1);
      if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  auto a = MyFunction::apply(
      torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto b = Reenter::apply(
      torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto v = a * b;
  v.backward();

  // All the reentrant tasks should be prioritized over the MyFunction backward
  // task.
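  // Reenter::backward runs 9 times (the saved value counts down from 8 to 0),
  // recording a 1 each time, while the single MyFunction::backward records a
  // 0; since reentrant tasks run at higher priority, the 0 comes last.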
  ASSERT_EQ(order.size(), 10);
  ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
  ASSERT_EQ(order.back(), 0);
  // Clear static variable in case test gets executed in a loop
  order.clear();
}

TEST(CustomAutogradTest, Hooks) {
  Variable x = torch::ones({5, 5}, torch::requires_grad());
  Variable y = torch::ones({5, 5}) * 4;
  y.set_requires_grad(true);

  int counter = 0;

  std::function<void(int, Variable)> bw_hook(
      [&counter](int inc, Variable grad) { counter += inc; });

  Variable z = x * x + x * 2 + x * y + y;
  x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); });
  auto hook_1 =
      z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 1);

  auto hook_2 =
      z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 4);

  z.remove_hook(hook_2);
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 5);

  std::function<Variable(Variable)> bw_hook_modify(
      [](Variable grad) { return grad.mul(2); });

  z.remove_hook(hook_1);
  z.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), true, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);

  y.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), false, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);

  ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}

TEST(CustomAutogradTest, HookNone) {
  struct NoneGradientFunction : public Function<NoneGradientFunction> {
    static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
      return {x, y};
    }

    static variable_list backward(AutogradContext* ctx, variable_list grad) {
      return {grad[0], Variable()};
    }
  };

  bool was_called = false;

  auto hook = ([&was_called](Variable grad) {
    ASSERT_TRUE(grad.defined());
    was_called = true;
  });

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5});

  auto out = NoneGradientFunction::apply(x, y);
  Variable rx = x[0], ry = x[1];

  rx.register_hook(hook);
  ry.register_hook(hook);
  (rx + ry).sum().backward();
  ASSERT_TRUE(was_called);
}
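// backward() can be restricted to a subset of the graph via its `inputs`
// argument: gradients are accumulated only into the listed tensors, as the
// tests below check (y.grad() stays undefined when only {x} is passed).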
TEST(CustomAutogradTest, BackwardWithInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  Variable x_grad_expected = 2 * x + y;
  Variable y_grad_expected = x + 2 * y;

  z.backward(torch::ones({5, 5}), false, false, {x});

  ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
  ASSERT_FALSE(y.grad().defined());
}

TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  Variable x_grad_expected = 2 * x + y;
  Variable y_grad_expected = x + 2 * y;
  ASSERT_THROWS_WITH(
      z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}),
      "cannot be empty");
}

TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x;
  Variable w = y * z + x * y + y * y;

  // With w = y*z + x*y + y*y and z = x*x:
  // dw/dx = 2*x*y + y (through z) and dw/dz = y.
  Variable x_grad_expected = 2 * x * y + y;
  Variable z_grad_expected = y;

  w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z});

  ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
  ASSERT_VARIABLE_EQ(z.grad(), z_grad_expected);
  ASSERT_FALSE(y.grad().defined());
}

TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
  c10::Warning::WarnAlways guard(true);

  torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true);
  auto z = x * x;
  {
    WarningCapture warnings;
    z.backward(torch::ones({5, 5}), c10::nullopt, true);
    ASSERT_TRUE(
        warnings.str().find("Using backward() with create_graph=True") !=
        std::string::npos);
  }

  {
    WarningCapture warnings;
    torch::autograd::backward({z}, {torch::ones({5, 5})}, c10::nullopt, true);
    ASSERT_TRUE(
        warnings.str().find("Using backward() with create_graph=True") !=
        std::string::npos);
  }
}

/**
 * Tests for AutogradNotImplementedFallback
 * - Check that we created the NotImplemented kernel when inputs require grad,
 *   but when no inputs require grad, we should not create this node
 * - check_inplace logic
 * - view ops (TODO: not an official view yet, update this once InplaceOrView
 *   kernel is landed)
 * - TODO: Tests for NDEBUG checks?
 * - tensorlist input and output
 * - multiple outputs / non-tensor output
 * - rebase_history vs set_history
 */
namespace {

torch::Tensor inplace_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return self.add_(other);
}

std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  other.add_(self);
  self.add_(other);
  return std::tuple<torch::Tensor, torch::Tensor>(self, other);
}

std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  // This is not allowed. We test below that calling into the boxed kernel
  // will raise an error.
  auto self_view = self.view(-1);
  auto other_view = other.view(-1);
  return std::tuple<torch::Tensor, torch::Tensor>(self_view, other_view);
}

int64_t ret_single_non_tensor(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return 12;
}

torch::Tensor opt_op(
    const torch::Tensor& self,
    const c10::optional<torch::Tensor>& other) {
  if (other.has_value()) {
    return self + other.value();
  } else {
    return self.clone();
  }
}

torch::Tensor my_custom_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return self + other;
}

std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  auto a = self - other;
  auto b = self + other;
  return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12);
}

torch::Tensor view_op(const torch::Tensor& self, const torch::Tensor& other) {
  return self.view(-1);
}

std::vector<torch::Tensor> ret_tensor_vector(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  std::vector<torch::Tensor> out;
  out.push_back(self + other);
  out.push_back(self - other);
  return out;
}

torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
  const auto& res = self.clone();
  for (const auto& t : other) {
    res.add_(t);
  }
  return res;
}

#define REGISTER_TEST_OP(name, schema, fn)                                  \
  auto m = MAKE_TORCH_LIBRARY(_test);                                       \
  m.def(schema);                                                            \
  auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd);               \
  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);                         \
  m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn));                    \
  m_autograd.impl(                                                          \
      name, c10::DispatchKey::Autograd, autogradNotImplementedFallback());
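// assertBasicChecks runs an op registered through REGISTER_TEST_OP twice:
// once with an input that requires grad (backward should then hit the
// NotImplemented node and raise "is not implemented"), and once with inputs
// that do not require grad (no grad_fn should be attached, so backward()
// fails with the usual "does not require grad" error).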
template <typename F>
void assertBasicChecks(F op) {
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto c = torch::tensor({1.}, {torch::kFloat32});

  // If any inputs require grad, the fallback node should complain on backward
  auto out1 = op(a, b);
  ASSERT_THROWS_WITH(out1.backward(), "is not implemented");

  // Should not have grad_fn if none of the inputs require grad
  auto out2 = op(b, c);
  ASSERT_THROWS_WITH(
      out2.backward(),
      "element 0 of tensors does not require grad and does not have a grad_fn");

  // TODO: Forward AD Tests?
}

} // namespace

// These tests trigger an MSVC bug in the internal arvr build
// Reproduce with: buck build @arvr/mode/win/opt //xplat/caffe2:autograd_libtorch_test_ovrsource
// It is probably caused by the lambda, see
// https://github.com/pytorch/pytorch/issues/48763
#if !defined(_MSC_VER)

TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
  REGISTER_TEST_OP(
      "ret_single_non_tensor",
      "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int",
      ret_single_non_tensor);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_single_non_tensor", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(
        opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b));
}

TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
  REGISTER_TEST_OP(
      "two_pairs_of_view_op",
      "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))",
      two_pairs_of_view_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_pairs_of_view_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});
  ASSERT_THROWS_WITH(
      op(a, b),
      "Expected only a single output in the operator schema to have a non-write alias annotation");
}

TEST(TestAutogradNotImplementedFallback, InplaceOp) {
  REGISTER_TEST_OP(
      "inplace_op",
      "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)",
      inplace_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Check in-place
  ASSERT_THROWS_WITH(
      op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  op(b, a);
  a = a.clone();
  b = b.clone();
  auto c = op(a, b);
  ASSERT_TRUE(torch::allclose(c, inplace_op(a, b)));

  // Test in-place on view
  auto base =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto view = base.view(-1);
  auto t = torch::tensor({1.}, {torch::kFloat32});

  torch::Tensor v_nograd;
  {
    c10::NoGradGuard guard;
    v_nograd = base.view(-1);
    op(v_nograd, t);
  }

  ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
  ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());

  // TODO: once we have the InplaceOrView kernel, re-enable this since the
  // version counter would actually be incremented
  // ASSERT_THAT(op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
}

TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
  REGISTER_TEST_OP(
      "two_arg_inplace_op",
      "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))",
      two_arg_inplace_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_arg_inplace_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Both are modified in-place!
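  // `a` is a leaf that requires grad, so the in-place write is rejected no
  // matter whether it is passed as `self` or as `other`.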
  ASSERT_THROWS_WITH(
      op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  ASSERT_THROWS_WITH(
      op(b, a),
      "a leaf Variable that requires grad is being used in an in-place operation");
}

TEST(TestAutogradNotImplementedFallback, OptOp) {
  REGISTER_TEST_OP(
      "opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
  auto op = [&](const torch::Tensor& _1,
                const c10::optional<torch::Tensor>& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const c10::optional<torch::Tensor>&>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b)));
  ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {})));
}

TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
  REGISTER_TEST_OP(
      "my_custom_op",
      "_test::my_custom_op(Tensor self, Tensor other) -> Tensor",
      my_custom_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };

  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
  REGISTER_TEST_OP(
      "ret_tuple_non_tensor",
      "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)",
      ret_tuple_non_tensor);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tuple_non_tensor", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    torch::Tensor out0;
    torch::Tensor out1;
    int64_t out2;
    auto out = callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor, int64_t>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
    std::tie(out0, out1, out2) = std::move(out);
    return out0;
  };

  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, ViewOp) {
  REGISTER_TEST_OP(
      "view_op",
      "_test::view_op(Tensor(a) self, Tensor other) -> Tensor(a)",
      view_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
  REGISTER_TEST_OP(
      "ret_tensor_vector",
      "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]",
      ret_tensor_vector);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tensor_vector", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        std::vector<torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2)[0];
  };
  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
  REGISTER_TEST_OP(
      "tensorlist_op",
      "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor",
      tensorlist_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::tensorlist_op", "");
  auto op = [&](torch::Tensor _1, at::TensorList _2) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(
        opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32});
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  std::vector<torch::Tensor> vec = {b, c};
  auto out = op(a, vec);

  ASSERT_THROWS_WITH(
      torch::autograd::grad({out}, {vec[0]}),
      "One of the differentiated Tensors does not require grad");
  ASSERT_THROWS_WITH(
      torch::autograd::grad({out}, {vec[1]}), "is not implemented");

  ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
}

#endif

// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
// test_save_output_nr
// test_free_deep_graph_pyfunction
// test_naughty_anomaly_access
// test_naughty_autograd-function_stashing_ctx
// test_custom_autograd_repeated_grad_grad
// test_return_leaf
// test_anomaly_detect_nan
// test_no_grad_copy