#include <gtest/gtest.h>

#include <ATen/core/boxing/impl/test_helpers.h>

#include <torch/torch.h>

#include <ATen/core/op_registration/op_registration.h>

#include <torch/csrc/autograd/functions/basic_ops.h>

#include <test/cpp/api/support.h>

using namespace torch::autograd;
using namespace torch::test;

#define ASSERT_VARIABLE_EQ(a,b) ASSERT_TRUE(torch::allclose((a),(b)))
#define EXPECT_VARIABLE_EQ(a,b) EXPECT_TRUE(torch::allclose((a),(b)))

// Recursively describes the autograd graph rooted at `node` as
// "name(<descriptions of next edges>)", printing "None" for a null node.
std::string graph_desc(std::shared_ptr<Node> node) {
  if (!node) {
    return "None";
  }
  auto result = node->name() + "(";
  auto next_edges = node->next_edges();
  for (auto& edge : next_edges) {
    result += graph_desc(edge.function);
  }
  return result + ")";
}

Variable simple_fn(const Variable& x, const Variable& y) {
  return x + 2 * y + x * y;
}
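
// For simple_fn(x, y) = x + 2*y + x*y the analytic gradients are
//   d/dx = 1 + y   and   d/dy = 2 + x,
// which is what the AutogradAPITests below assert (scaled by the incoming
// gradient where one is supplied explicitly).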

TEST(AutogradAPITests, BackwardSimpleTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res.sum()}, {});

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2})*2);
}

TEST(AutogradAPITests, BackwardTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res}, {torch::ones({2, 2})}, {}, true);

  backward({res}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
  ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2})*2));
}

TEST(AutogradAPITests, GradSimpleTest) {
  // basic grad
  Variable x = torch::randn({2,2}, torch::requires_grad());
  Variable y = torch::randn({2,2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, GradTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  res.backward(torch::ones({2, 2}), false, true);

  Variable x_grad = y + torch::ones({2, 2});
  Variable y_grad = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);

  Variable grad_sum = 2 * x.grad() + y.grad();
  auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);

  ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);
}
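
// In GradTest above, grad_sum = 2 * x.grad() + y.grad()
//                             = 2 * (1 + y) + (2 + x),
// so its gradient w.r.t. x is identically 1; taking grad() of grad_sum (made
// possible by create_graph=true on the first backward call) therefore returns
// a tensor of ones, which is what x_hv[0] is compared against.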

TEST(AutogradAPITests, GradNonLeafTest) {
  Variable x_init = torch::randn({2, 2}, torch::requires_grad());
  Variable x = x_init;
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable grad_output = torch::ones({2, 2});

  for (int i = 0; i < 5; ++i) {
    auto res = simple_fn(x, y);
    auto input_grads = grad({res}, {x}, {grad_output}, {}, true);

    Variable grad_x_expected = y + torch::ones({2, 2});
    ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
    ASSERT_FALSE(x.grad().defined());
    ASSERT_FALSE(y.grad().defined());
    x = x + 0.05 * input_grads[0];
  }

  float val_init = simple_fn(x_init, y).sum().item().toFloat();
  float val_final = simple_fn(x, y).sum().item().toFloat();
  ASSERT_TRUE(val_final > val_init);

  x.backward(grad_output, false, true);
  ASSERT_TRUE(x_init.grad().defined());
  ASSERT_TRUE(y.grad().defined());
}

TEST(AutogradAPITests, GradUnreachableTest) {
  Variable x = torch::ones({1}, torch::requires_grad());
  Variable y = torch::ones({1}, torch::requires_grad());

  Variable z = x * 2;
  Variable w = y * 2;

  auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // This is slightly different than the case above, because z doesn't even
  // have a grad accumulator allocated.
  z = torch::ones({1}, torch::requires_grad());
  grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);

  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // allow_unused=False, but grads contains None inside, should throw
  ASSERT_THROWS_WITH(grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
}
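
// GradUnreachableTest above exercises the allow_unused flag of grad(): when an
// input is unreachable from the outputs, grad() returns an undefined Tensor for
// it if allow_unused is true, and throws ("Set allow_unused=True") otherwise.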

TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
  // Test that certain nodes are not erroneously executed when an input
  // is unreachable. See #39784
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var) {
      return var;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      ADD_FAILURE() << "This node should not be executed!";
      return grad_output;
    }
  };

  auto x = torch::randn(1, torch::requires_grad());
  auto x1 = torch::randn(1);
  auto x2 = MyFunction::apply(x + x1);

  auto y = torch::randn(1, torch::requires_grad());
  auto grad_res = torch::autograd::grad({x2}, {y}, {}, {}, false, true);
  ASSERT_FALSE(grad_res[0].defined());
}

TEST(AutogradAPITests, EmptyInput) {
  Variable x = torch::ones({1}, torch::requires_grad());
  ASSERT_THROWS_WITH(grad({x * 2}, /*inputs=*/{}, {x}),
                     "grad requires non-empty inputs.");
}

TEST(AutogradAPITests, RetainGrad) {
  auto input = torch::rand({1, 3}, torch::requires_grad());
  auto h1 = input * 3;
  auto out = (h1 * h1).sum();

  {
    // Warning when grad is accessed for non-leaf tensor
    WarningCapture warnings;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_TRUE(
        warnings.str().find("is not a leaf") != std::string::npos);
  }
  // It should be possible to call retain_grad() multiple times
  h1.retain_grad();
  h1.retain_grad();
  {
    // If retain_grad is true for a non-leaf tensor,
    // there should not be any warning when grad is accessed
    WarningCapture warnings;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_FALSE(
        warnings.str().find("is not a leaf") != std::string::npos);
  }

  // Gradient should be accumulated
  // NOLINTNEXTLINE(bugprone-argument-comment)
  out.backward({}, /*keep_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
  // NOLINTNEXTLINE(bugprone-argument-comment)
  out.backward({}, /*keep_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());

  {
    torch::NoGradGuard no_grad;
    input.grad().zero_();
  }
  // It should be a no-op for leaves
  input.retain_grad();
  input.retain_grad();
  out.backward();
  ASSERT_VARIABLE_EQ(input * 18, input.grad());
}

TEST(AutogradAPITests, AnomalyMode) {
  // Anomaly mode should emit the forward traceback as a warning and then throw
  torch::autograd::DetectAnomalyGuard detect_anomaly;
  {
    WarningCapture warnings;
    auto x = torch::tensor({5.0}, torch::requires_grad());
    auto y = x * x;
    auto z = y * y;
    y += 1;
    ASSERT_THROWS_WITH(z.backward(), "inplace");
    ASSERT_TRUE(
        warnings.str().find("Traceback of forward") != std::string::npos);
  }
  {
    WarningCapture warnings;
    // Double backward
    auto x = torch::tensor({0.0}, torch::requires_grad());
    auto y = x.pow(1.5);
    auto gr =
        // NOLINTNEXTLINE(bugprone-argument-comment)
        grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
    ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})}), "returned nan");
    auto msgs = warnings.messages();
    ASSERT_EQ(msgs.size(), 2);
    ASSERT_TRUE(
        msgs[0].find("Traceback of forward call that caused the error") !=
        std::string::npos);
    ASSERT_TRUE(
        msgs[1].find(
            "Traceback of forward call that induced the previous calculation") !=
        std::string::npos);
  }
}
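
// The tests below exercise custom autograd Functions. The general pattern,
// used throughout this file, is to subclass torch::autograd::Function<T> and
// provide static forward()/backward() methods; a minimal sketch (names are
// illustrative only):
//
//   struct Square : public Function<Square> {
//     static Variable forward(AutogradContext* ctx, Variable x) {
//       ctx->save_for_backward({x});
//       return x * x;
//     }
//     static variable_list backward(AutogradContext* ctx, variable_list grad_output) {
//       auto x = ctx->get_saved_variables()[0];
//       return {2 * x * grad_output[0]};
//     }
//   };
//
// Calling Square::apply(t) records a node whose backward() receives the
// incoming gradients and must return one gradient per forward() input.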

TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul*var2 + var1*var2;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      variable_list output = {grad_output[0] + grad_output[0]*var2, Variable(), grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  Variable x = torch::randn({5,5}, torch::requires_grad());
  Variable y = torch::randn({5,5}, torch::requires_grad());
  auto res = MyFunction::apply(x, 2, y);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5,5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5,5})*2);
}

TEST(CustomAutogradTest, FunctionReturnsInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var1) {
      return var1;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0]*2};
    }
  };

  Variable x(torch::ones(1, torch::requires_grad()));
  MyFunction::apply(x).backward(torch::ones(1), true, true);
  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
}

TEST(CustomAutogradTest, FunctionReturnsUndefined) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var) {
      return var * 2;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      at::Tensor undefined_tensor;
      return {undefined_tensor};
    }
  };

  auto x = torch::ones(1, torch::requires_grad());

  MyFunction::apply(x).backward();
  ASSERT_FALSE(x.grad().defined());

  MyFunction::apply(x.pow(2)).backward();
  ASSERT_FALSE(x.grad().defined());

  MyFunction::apply(x).sum().backward();
  ASSERT_FALSE(x.grad().defined());

  ASSERT_FALSE(torch::autograd::grad(
      {MyFunction::apply(x)}, {x}, {}, false, false, true)[0].defined());
}

TEST(CustomAutogradTest, MaterializeGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var) {
      return var;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1));
      return grad_output;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
}

TEST(CustomAutogradTest, DontMaterializeGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var) {
      ctx->set_materialize_grads(false);
      return var;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      EXPECT_FALSE(grad_output[0].defined());
      return grad_output;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
}
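
// By default, undefined incoming gradients are materialized as zero tensors
// before backward() is called (MaterializeGrads above sees torch::zeros(1));
// calling ctx->set_materialize_grads(false) in forward() opts out, so
// backward() receives the undefined tensor as-is (DontMaterializeGrads above).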

TEST(CustomAutogradTest, NoGradCustomFunction) {
  // Custom Function should respect grad mode
  struct MyOp : public Function<MyOp> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      return x+1;
    }

    static variable_list backward(AutogradContext *ctx, variable_list dy) {
      return dy;
    }
  };

  auto x = torch::ones({5,5}, torch::requires_grad());
  {
    at::NoGradGuard no_grad;
    auto y = MyOp::apply(x);
    ASSERT_FALSE(y.requires_grad());
  }
}

TEST(CustomAutogradTest, MarkDirty) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable v) {
      // Change the value in-place
      auto v_data = v.data_ptr<float>();
      v_data[0] = 2;
      ctx->mark_dirty({v});
      return v;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return { (grad_output[0]*2.0) };
    }
  };

  // Clone here because modifying a leaf Variable in-place is not allowed
  auto x = torch::randn({5,5}, torch::requires_grad()).clone();
  auto version_before = x._version();
  auto out = MyFunction::apply(x);
  auto version_after = x._version();
  ASSERT_TRUE(version_after >= (version_before + 1));
  out.sum().backward();
}
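
// ctx->mark_dirty() declares that an input was modified in-place in forward();
// MarkDirty above checks the observable effect: the tensor's version counter
// (x._version()) is bumped after apply(), and backward still runs.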

TEST(CustomAutogradTest, MarkNonDifferentiable) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable v) {
      Variable output = v > 0;
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return { (grad_output[0]*0.0) };
    }
  };

  auto x = torch::randn({5,5}, torch::requires_grad());
  auto mask = MyFunction::apply(x);
  ASSERT_FALSE(mask.requires_grad());
  auto y = x.masked_fill(mask, 0);
  y.sum().backward();
}

TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext *ctx, Variable input) {
      Variable a = input+1;
      Variable b = input+2;
      ctx->mark_non_differentiable({a});
      return {a, b};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
      EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5,5}));
      EXPECT_VARIABLE_EQ(grad_b, torch::ones({5,5}));
      return {grad_b};
    }
  };

  auto x = torch::randn({5,5}, torch::requires_grad());
  auto out = MyFunction::apply(x);

  ASSERT_FALSE(out[0].requires_grad());
  ASSERT_TRUE(out[1].requires_grad());
  out[1].sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5,5}));
}
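
// Outputs passed to ctx->mark_non_differentiable() come back with
// requires_grad == false; in the mixed case above, backward() still receives a
// (zero-filled) gradient slot for the non-differentiable output alongside the
// real gradient for the differentiable one.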

TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable input) {
      auto output = input.clone();
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
      return {};
    }
  };

  auto x = torch::randn({5,5}, torch::requires_grad());
  auto r = MyFunction::apply(x * x);
  (r * x).sum().backward();
}

TEST(CustomAutogradTest, ReturnLeafInplace) {
  struct Inplace : public Function<Inplace> {
    static variable_list forward(AutogradContext *ctx, Variable a, Variable b) {
      ctx->mark_dirty({a});
      return {a.add_(b), b+2};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0], grad_output[0] + grad_output[1]};
    }
  };

  Variable x = torch::randn({5,5});
  Variable y = torch::randn({5,5}, torch::requires_grad());

  auto out = Inplace::apply(x, y);
  auto &q = out[0];
  ASSERT_TRUE(torch::equal(q, x));
  ASSERT_TRUE(q.requires_grad());
  q.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5,5}));
}

TEST(CustomAutogradTest, ReturnDuplicateInplace) {
  struct DoubleInplace : public Function<DoubleInplace> {
    static variable_list forward(AutogradContext *ctx, Variable x) {
      x.mul_(2);
      ctx->mark_dirty({x});
      return {x, x};
    }

    static variable_list backward(AutogradContext *ctsx, variable_list grad_outputs) {
      return {grad_outputs[0]*2 + grad_outputs[1]*2};
    }
  };

  auto x = torch::randn({5,5}, torch::requires_grad());

  ASSERT_THROWS_WITH(DoubleInplace::apply(x), "leaf Variable that requires grad");
  // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one output");

  auto out = DoubleInplace::apply(x.clone());
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}

TEST(CustomAutogradTest, ReturnDuplicate) {
  struct DoubleDuplicate : public Function<DoubleDuplicate> {
    static variable_list forward(AutogradContext *ctx, Variable x) {
      auto output = x*2;
      return {output, output};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
      return {grad_outputs[0]*2 + grad_outputs[1]*2};
    }
  };

  auto x = torch::randn({5,5}, torch::requires_grad());
  auto out = DoubleDuplicate::apply(x);
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}

TEST(CustomAutogradTest, SaveEmptyForBackward) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable input) {
      ctx->save_for_backward({Variable(), input, Variable()});
      return input*input;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      auto saved = ctx->get_saved_variables();
      EXPECT_FALSE(saved[0].defined());
      EXPECT_FALSE(saved[2].defined());
      return {saved[1] * 2 * grad_output[0]};
    }
  };

  Variable x = torch::randn({5,5}, torch::requires_grad());
  auto y = MyFunction::apply(x);
  y.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), 2*x);
}

TEST(CustomAutogradTest, InvalidGradients) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      return x*2;
    }

    static variable_list backward(AutogradContext *ctsx, variable_list grad_outputs) {
      return {torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
    }
  };

  auto input1 = torch::randn({5,5}, torch::dtype(torch::kFloat).requires_grad(true));
  ASSERT_THROWS_WITH(
      MyFunction::apply(input1).sum().backward(), "expected shape");
  auto input2 = torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
}

TEST(CustomAutogradTest, NoGradInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad_outputs) {
      return grad_outputs;
    }
  };

  Variable x = torch::randn({5,5}, torch::requires_grad());
  Variable y;
  {
    at::NoGradGuard no_grad;
    y = MyFunction::apply(x);
  }

  ASSERT_TRUE(x.requires_grad());
  ASSERT_FALSE(y.grad_fn());
}

TEST(CustomAutogradTest, TooManyGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable input) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      grad_output.insert(grad_output.end(), {Variable(), Variable()});
      return grad_output;
    }
  };
}

TEST(CustomAutogradTest, DepNoGrad) {
  struct F1 : public Function<F1> {
    static variable_list forward(AutogradContext *ctx, Variable input) {
      auto out = torch::randn(input.sizes());
      ctx->mark_non_differentiable({out});
      return {input, out};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0]};
    }
  };

  struct F2 : public Function<F2> {
    static Variable forward(AutogradContext*, Variable input, Variable ignore) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      return {grad_output[0], Variable()};
    }
  };

  auto x = torch::randn(5, torch::requires_grad());
  auto out = F1::apply(x);
  Variable &a = out[0], &b = out[1];
  b = b+1; // Separate F1 and F2 by another operation
  ASSERT_TRUE(a.requires_grad());
  ASSERT_FALSE(b.requires_grad());

  auto c = F2::apply(a, b);
  c.backward(torch::ones(c.sizes()), false, false);
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}

TEST(CustomAutogradTest, Reentrant) {
  static Variable y_data = torch::randn({2, 2});
  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext *ctx, Variable input) {
      Variable output;
      {
        at::AutoGradMode enable_grad(true);
        auto x = make_variable(input.tensor_data(), true);
        auto y = make_variable(y_data.tensor_data(), true);
        output = x*y;

        ctx->saved_data["x"] = x;
        ctx->saved_data["y"] = y;
        ctx->saved_data["output_var"] = output;
      }
      return output.detach();
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      {
        at::AutoGradMode enable_grad(true);
        auto out = ctx->saved_data["output_var"].toTensor();
        out.sum().backward();
      }
      return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
    }
  };

  auto x = torch::randn({2,2}, torch::requires_grad());
  auto out = Reenter::apply(x);
  out.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), y_data);
}
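
// The reentrant tests rely on calling backward() from inside a custom
// backward(): Reenter builds a small graph under at::AutoGradMode(true) in
// forward(), runs backward on it inside its own backward(), and then forwards
// the resulting grad. DeepReentrant below repeats this roughly 8k times
// (starting from 8193 and decrementing) to check that the engine does not
// overflow the stack.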

// NOTE: If this fails for apparently unrelated reasons in TSAN be aware of
// the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950
TEST(CustomAutogradTest, DeepReentrant) {
  struct DeepReenter : public Function<DeepReenter> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(AutogradContext* ctx, variable_list grad_output) {
      if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  // This should not stack overflow
  auto v = torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
  DeepReenter::apply(v).sum().backward();
}

TEST(CustomAutogradTest, ReentrantPriority) {
  static std::vector<int> order;

  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad) {
      order.push_back(0);
      return grad;
    }
  };

  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(AutogradContext* ctx, variable_list grad_output) {
      order.push_back(1);
      if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  auto a = MyFunction::apply(torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto b = Reenter::apply(torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto v = a*b;
  v.backward();

  // All the reentrant tasks should be prioritized over the MyFunction backward
  // task.
  ASSERT_EQ(order.size(), 10);
  ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
  ASSERT_EQ(order.back(), 0);
  // Clear the static variable in case the test gets executed in a loop
  order.clear();
}

TEST(CustomAutogradTest, Hooks) {
  Variable x = torch::ones({5,5}, torch::requires_grad());
  Variable y = torch::ones({5,5})*4;
  y.set_requires_grad(true);

  int counter = 0;

  std::function<void(int, Variable)> bw_hook([&counter](int inc, Variable grad) {
    counter += inc;
  });

  Variable z = x * x + x * 2 + x * y + y;
  x.register_hook([&bw_hook](Variable grad) {
    bw_hook(0, grad);
  });
  auto hook_1 = z.register_hook([&bw_hook](Variable grad) {
    bw_hook(1, grad);
  });
  z.backward(torch::ones({5,5}), true, true);
  ASSERT_EQ(counter, 1);

  auto hook_2 = z.register_hook([&bw_hook](Variable grad) {
    bw_hook(2, grad);
  });
  z.backward(torch::ones({5,5}), true, true);
  ASSERT_EQ(counter, 4);

  z.remove_hook(hook_2);
  z.backward(torch::ones({5,5}), true, true);
  ASSERT_EQ(counter, 5);

  std::function<Variable(Variable)> bw_hook_modify([](Variable grad) {
    return grad.mul(2);
  });

  z.remove_hook(hook_1);
  z.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5,5}), true, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x+1)*2);

  y.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5,5}), false, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x+1)*4);

  ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}
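
// Notes on the Hooks test above: register_hook() returns an index that can be
// passed to remove_hook(); a hook returning void only observes the gradient,
// while a hook returning a Variable (bw_hook_modify) replaces the gradient
// that continues to flow backward, which is why y.grad() is doubled and then
// quadrupled.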

TEST(CustomAutogradTest, HookNone) {
  struct NoneGradientFunction : public Function<NoneGradientFunction> {
    static variable_list forward(AutogradContext *ctx, Variable x, Variable y) {
      return {x, y};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad) {
      return {grad[0], Variable()};
    }
  };

  bool was_called = false;

  auto hook = ([&was_called](Variable grad) {
    ASSERT_TRUE(grad.defined());
    was_called = true;
  });

  auto x = torch::randn({5,5}, torch::requires_grad());
  auto y = torch::randn({5,5});

  auto out = NoneGradientFunction::apply(x, y);
  Variable rx = x[0], ry = x[1];

  rx.register_hook(hook);
  ry.register_hook(hook);
  (rx+ry).sum().backward();
  ASSERT_TRUE(was_called);
}

TEST(CustomAutogradTest, BackwardWithInputs) {
  Variable x = torch::randn({5,5}, torch::requires_grad());
  Variable y = torch::randn({5,5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  Variable x_grad_expected = 2 * x + y;
  Variable y_grad_expected = x + 2 * y;

  z.backward(torch::ones({5, 5}), false, false, {x});

  ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
  ASSERT_FALSE(y.grad().defined());
}

TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
  Variable x = torch::randn({5,5}, torch::requires_grad());
  Variable y = torch::randn({5,5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  Variable x_grad_expected = 2 * x + y;
  Variable y_grad_expected = x + 2 * y;
  ASSERT_THROWS_WITH(z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}), "cannot be empty");
}

TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
  Variable x = torch::randn({5,5}, torch::requires_grad());
  Variable y = torch::randn({5,5}, torch::requires_grad());
  Variable z = x * x;
  Variable w = y * z + x * y + y * y;

  Variable x_grad_expected = 2 * x * y + y;
  Variable z_grad_expected = y;

  w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z});

  ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
  ASSERT_VARIABLE_EQ(z.grad(), z_grad_expected);
  ASSERT_FALSE(y.grad().defined());
}

TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
  c10::Warning::WarnAlways guard(true);

  torch::Tensor x = torch::randn({5,5}).set_requires_grad(true);
  auto z = x * x;
  {
    WarningCapture warnings;
    z.backward(torch::ones({5, 5}), c10::nullopt, true);
    ASSERT_TRUE(
        warnings.str().find("Using backward() with create_graph=True") != std::string::npos);
  }

  {
    WarningCapture warnings;
    torch::autograd::backward({z}, {torch::ones({5, 5})}, c10::nullopt, true);
    ASSERT_TRUE(
        warnings.str().find("Using backward() with create_graph=True") != std::string::npos);
  }
}

/**
 * Tests for AutogradNotImplementedFallback
 * - Check that the NotImplemented node is created when inputs require grad,
 *   and not created when no inputs require grad
 * - check_inplace logic
 * - view ops (TODO: not an official view yet; update this once the InplaceOrView kernel has landed)
 * - TODO: tests for NDEBUG checks?
 * - tensorlist input and output
 * - multiple outputs / non-tensor output
 * - rebase_history vs set_history
 */
namespace {

torch::Tensor inplace_op(const torch::Tensor& self, const torch::Tensor& other) {
  return self.add_(other);
}

std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(const torch::Tensor& self, const torch::Tensor& other) {
  other.add_(self);
  self.add_(other);
  return std::tuple<torch::Tensor, torch::Tensor>(self, other);
}

std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(const torch::Tensor& self, const torch::Tensor& other) {
  // This is not allowed. Below we test that calling this through the boxed kernel raises an error.
  auto self_view = self.view(-1);
  auto other_view = other.view(-1);
  return std::tuple<torch::Tensor, torch::Tensor>(self_view, other_view);
}

int64_t ret_single_non_tensor(const torch::Tensor& self, const torch::Tensor& other) {
  return 12;
}

torch::Tensor opt_op(const torch::Tensor& self, const c10::optional<at::Tensor>& other) {
  if (other.has_value()) {
    return self + other.value();
  } else {
    return self.clone();
  }
}

torch::Tensor my_custom_op(const torch::Tensor& self, const torch::Tensor& other) {
  return self + other;
}

std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(const torch::Tensor& self, const torch::Tensor& other) {
  auto a = self - other;
  auto b = self + other;
  return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12);
}

torch::Tensor view_op(const torch::Tensor& self, const torch::Tensor& other) {
  return self.view(-1);
}

std::vector<at::Tensor> ret_tensor_vector(const torch::Tensor& self, const torch::Tensor& other) {
  std::vector<at::Tensor> out;
  out.push_back(self + other);
  out.push_back(self - other);
  return out;
}

torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
  const auto& res = self.clone();
  for (const auto& t : other) {
    res.add_(t);
  }
  return res;
}

#define REGISTER_TEST_OP(name, schema, fn)                                    \
  auto m = MAKE_TORCH_LIBRARY(_test);                                         \
  m.def(schema);                                                              \
  auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd);                 \
  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);                           \
  m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn));                      \
  m_autograd.impl(name, c10::DispatchKey::Autograd, autogradNotImplementedFallback());
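
// REGISTER_TEST_OP defines a test-only operator in the `_test` namespace: it
// registers the schema, a CPU kernel, and autogradNotImplementedFallback() as
// the implementation under the Autograd dispatch key. The tests below expand
// it inside each TEST body and then look the operator up through the
// dispatcher, e.g. (sketch, with a hypothetical schema string):
//
//   REGISTER_TEST_OP("my_op", "_test::my_op(Tensor self, Tensor other) -> Tensor", my_custom_op);
//   auto handle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_op", "");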

template <typename F>
void assertBasicChecks(F op) {
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto c = torch::tensor({1.}, {torch::kFloat32});

  // If any input requires grad, backward should hit the not-implemented node
  auto out1 = op(a, b);
  ASSERT_THROWS_WITH(out1.backward(), "is not implemented");

  // Should not have a grad_fn if no inputs require grad
  auto out2 = op(b, c);
  ASSERT_THROWS_WITH(out2.backward(), "element 0 of tensors does not require grad and does not have a grad_fn");

  // TODO: Forward AD Tests?
}
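
// assertBasicChecks() is shared by several tests below: for an op registered
// with autogradNotImplementedFallback(), backward through an output that
// requires grad should fail with "is not implemented", while an output whose
// inputs do not require grad gets no grad_fn at all.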

} // namespace

// These tests trigger an MSVC bug in the internal arvr build
// Reproduce with: buck build @arvr/mode/win/opt //xplat/caffe2:autograd_libtorch_test_ovrsource
// It is probably caused by the lambda, see https://github.com/pytorch/pytorch/issues/48763
#if !defined(_MSC_VER)

TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
  REGISTER_TEST_OP("ret_single_non_tensor", "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int", ret_single_non_tensor);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_single_non_tensor", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b));
}

TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
  REGISTER_TEST_OP("two_pairs_of_view_op", "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))", two_pairs_of_view_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::two_pairs_of_view_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<std::tuple<torch::Tensor, torch::Tensor>, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
  };
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});
  ASSERT_THROWS_WITH(op(a, b),
      "Expected only a single output in the operator schema to have a non-write alias annotation");
}

TEST(TestAutogradNotImplementedFallback, InplaceOp) {
  REGISTER_TEST_OP("inplace_op", "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)", inplace_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Check in-place
  ASSERT_THROWS_WITH(op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  op(b, a);
  a = a.clone();
  b = b.clone();
  auto c = op(a, b);
  ASSERT_TRUE(torch::allclose(c, inplace_op(a, b)));

  // Test in-place on view
  auto base = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto view = base.view(-1);
  auto t = torch::tensor({1.}, {torch::kFloat32});

  torch::Tensor v_nograd;
  {
    c10::NoGradGuard guard;
    v_nograd = base.view(-1);
    op(v_nograd, t);
  }

  ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
  ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());

  // TODO: once we have the InplaceOrView kernel, re-enable this since the
  // version counter would actually be incremented
  // ASSERT_THAT(op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
}

TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
  REGISTER_TEST_OP("two_arg_inplace_op", "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))", two_arg_inplace_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::two_arg_inplace_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<std::tuple<torch::Tensor, torch::Tensor>, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
  };
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Both are modified in-place!
  ASSERT_THROWS_WITH(op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  ASSERT_THROWS_WITH(op(b, a),
      "a leaf Variable that requires grad is being used in an in-place operation");
}

TEST(TestAutogradNotImplementedFallback, OptOp) {
  REGISTER_TEST_OP("opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
  auto op = [&](const torch::Tensor& _1, const c10::optional<torch::Tensor>& _2) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, const c10::optional<torch::Tensor>&>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b)));
  ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {})));
}

TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
  REGISTER_TEST_OP("my_custom_op", "_test::my_custom_op(Tensor self, Tensor other) -> Tensor", my_custom_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
  };

  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
  REGISTER_TEST_OP("ret_tuple_non_tensor", "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)", ret_tuple_non_tensor);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tuple_non_tensor", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    torch::Tensor out0;
    torch::Tensor out1;
    int64_t out2;
    auto out = callOpUnboxed<std::tuple<torch::Tensor, torch::Tensor, int64_t>, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
    std::tie(out0, out1, out2) = std::move(out);
    return out0;
  };

  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, ViewOp) {
  REGISTER_TEST_OP("view_op", "_test::view_op(Tensor(a) self, Tensor other) -> Tensor(a)", view_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
  };
  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
  REGISTER_TEST_OP("ret_tensor_vector", "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]", ret_tensor_vector);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tensor_vector", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<std::vector<at::Tensor>, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2)[0];
  };
  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
  REGISTER_TEST_OP("tensorlist_op", "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor", tensorlist_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::tensorlist_op", "");
  auto op = [&](torch::Tensor _1, at::TensorList _2) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32});
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  std::vector<torch::Tensor> vec = {b, c};
  auto out = op(a, vec);

  ASSERT_THROWS_WITH(torch::autograd::grad({out}, {vec[0]}), "One of the differentiated Tensors does not require grad");
  ASSERT_THROWS_WITH(torch::autograd::grad({out}, {vec[1]}), "is not implemented");

  ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
}

#endif

// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
// test_save_output_nr
// test_free_deep_graph_pyfunction
// test_naughty_anomaly_access
// test_naughty_autograd-function_stashing_ctx
// test_custom_autograd_repeated_grad_grad
// test_return_leaf
// test_anomaly_detect_nan
// test_no_grad_copy