pytorch/test/cpp/api/autograd.cpp
richard 382ef1fda7 Autograd graphtask trim unnecessary edges (#82544)
### Introduction

Removing unnecessary weight-gradient calculations is important for applications that need high-order derivatives during training; however, the current autograd engine does not support this.

For more detail: the backward function of a `matmul` operator (e.g., `linear`, `addmm`, `mm`) performs two matmuls, one for the `input gradient` and another for the `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t. the nn input `x`, only the `input gradient` of each `matmul` operator is needed and the `weight gradient` is discarded. However, the current PyTorch autograd engine always calculates the `weight gradient` whenever `weight` requires grad, which is the case when the high-order derivative is computed during training.
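To make the two halves concrete, here is a minimal sketch of the two matmuls inside a linear layer's backward (the shapes and the `go` name for the incoming output gradient are illustrative assumptions, not the engine's actual kernels):

```cpp
#include <torch/torch.h>

int main() {
  // Illustrative shapes: y = x.matmul(w.t()) has shape (8, 3).
  auto x  = torch::randn({8, 4});   // input
  auto w  = torch::randn({3, 4});   // linear-layer weight
  auto go = torch::randn({8, 3});   // gradient flowing back into y
  // The matmul backward splits into two independent matmuls:
  auto grad_input  = go.matmul(w);      // d(loss)/d(x): the only one needed for derivatives w.r.t. x
  auto grad_weight = go.t().matmul(x);  // d(loss)/d(w): wasted work when no weight gradient is requested
  return 0;
}
```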

The attached figure shows the autograd graph of the following code snippet:
```py
import torch
# Definitions added here only to make the snippet runnable; shapes are illustrative.
x = torch.randn(8, 4, requires_grad=True)
weight = torch.randn(4, 4, requires_grad=True)
bias = torch.randn(4, requires_grad=True)
grad_outputs = torch.ones(8, 4)
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# second order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```
The path marked in the figure is not needed when calculating these derivatives.

<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">

### Issue
Related issue: https://github.com/pytorch/pytorch/issues/56500

### Method
When `torch.autograd.grad` is called, an `exec_info_` map is created for each `GraphTask`, which lets the engine filter out graph paths that are not needed. However, when the GraphTask calls into a node, the node still does not know which of its edges are actually needed. In the case of matmul, `weight.requires_grad` is `True`, so the weight gradient is always calculated.

Following https://github.com/pytorch/pytorch/issues/56500#issuecomment-825694656, this PR passes the graph task's thread-local `exec_info_` into the node, so the node can trim unnecessary edges during `torch.autograd.grad` calls.
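As a sketch of what this enables at the node level, a custom function's backward can query which input gradients the current graph task actually needs and skip the rest. The snippet below mirrors the `GraphTaskTrimEdges` test added in this file and uses the same `ctx->needs_input_grad()` query; the `TrimmedMul` op itself is a made-up example:

```cpp
#include <torch/torch.h>
#include <iostream>

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// Toy custom function: backward only does the work for inputs that the
// current graph task asked for, in the spirit of the GraphTaskTrimEdges test.
struct TrimmedMul : public torch::autograd::Function<TrimmedMul> {
  static Variable forward(AutogradContext* ctx, Variable x, Variable w) {
    ctx->save_for_backward({x, w});
    return x * w;
  }
  static variable_list backward(AutogradContext* ctx, variable_list grad_output) {
    auto saved = ctx->get_saved_variables();
    Variable grad_x, grad_w;
    if (ctx->needs_input_grad(0)) {  // d(out)/d(x) requested by this graph task
      grad_x = grad_output[0] * saved[1];
    }
    if (ctx->needs_input_grad(1)) {  // skipped when only the x gradient is requested
      grad_w = grad_output[0] * saved[0];
    }
    return {grad_x, grad_w};
  }
};

int main() {
  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto w = torch::randn({5, 5}, torch::requires_grad());
  auto out = TrimmedMul::apply(x, w);
  // Only x is passed as an input to grad(), so the grad_w branch above never runs.
  auto gx = torch::autograd::grad({out}, {x}, {torch::ones_like(out)})[0];
  std::cout << gx.sizes() << std::endl;
  return 0;
}
```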

### Benchmark
Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99

Benchmark result:
6 hidden layers, batch size 10,000, on an A100 GPU

FP32 results
| hessian benchmark             | FP32 (before) | FP32 (after)      | FP32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 55.658 ms     | 29.392 ms (1.90X) | 29.547 ms (1.90X)       |
| Linear + ReLU (with backward) | 81.173 ms     | 54.917 ms (1.47X) | 68.988 ms (1.18X)       |

TF32 results
| hessian benchmark             | TF32 (before) | TF32 (after)      | TF32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 19.801 ms     | 11.259 ms (1.76X) | 10.754 ms (1.84X)       |
| Linear + ReLU (with backward) | 29.167 ms     | 20.466 ms (1.42X) | 22.784 ms (1.28X)       |

For FP32, we get a 1.9X speedup for the hessian calculation (no backward) and a 1.47X speedup during training (with backward), which is even faster than functorch's `vmap(jacfwd(jacrev(...)))` implementation. (functorch has a performance regression in v0.2.0, https://github.com/pytorch/functorch/issues/989, so v0.1.1 is used for the benchmark.)
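For context, the benchmarked workload is the nested-`grad` (double backward) pattern in which only input gradients are ever requested. Below is a minimal sketch of that pattern in the C++ API exercised by the tests in this file; the shapes and toy network are assumptions, not the benchmark script itself:

```cpp
#include <torch/torch.h>

int main() {
  auto x = torch::randn({16, 8}, torch::requires_grad());
  auto w = torch::randn({8, 8}, torch::requires_grad());  // weight grad is never requested
  auto y = torch::relu(torch::nn::functional::linear(x, w)).pow(2);
  auto go = torch::ones_like(y);
  // First-order derivative of y w.r.t. x; create the graph for the second pass.
  auto gx = torch::autograd::grad({y}, {x}, {go},
                                  /*retain_graph=*/true, /*create_graph=*/true)[0];
  // Second-order derivative; with edge trimming, the weight-gradient matmuls inside
  // the linear backward can be skipped because only x is a grad() input.
  auto gxx = torch::autograd::grad({gx}, {x}, {torch::ones_like(gx)})[0];
  return 0;
}
```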

@zou3519 does functorch also include similar optimizations during hessian calculation? If not, what would we need to do so that functorch can also benefit from this PR?

### Testing

- [x] Figure out a way to unit test this change (now covered by the `GraphTaskTrimEdges` test in this file)

### Thanks
Thanks for the great blog post: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)

cc @zasdfgbnm @albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82544
Approved by: https://github.com/soulitzer
2022-08-11 18:50:09 +00:00


#include <ATen/core/boxing/impl/test_helpers.h>
#include <gtest/gtest.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/torch.h>
#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <test/cpp/api/support.h>
using namespace torch::autograd;
using namespace torch::test;
#define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b)))
#define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b)))
std::string graph_desc(std::shared_ptr<Node> node) {
if (!node) {
return "None";
}
auto result = node->name() + "(";
auto next_edges = node->next_edges();
for (auto& edge : next_edges) {
result += graph_desc(edge.function);
}
return result + ")";
}
Variable simple_fn(const Variable& x, const Variable& y) {
return x + 2 * y + x * y;
}
TEST(AutogradAPITests, BackwardSimpleTest) {
Variable x = torch::randn({2, 2}, torch::requires_grad());
Variable y = torch::randn({2, 2}, torch::requires_grad());
auto res = simple_fn(x, y);
backward({res.sum()}, {});
ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2);
}
TEST(AutogradAPITests, BackwardTest) {
Variable x = torch::randn({2, 2}, torch::requires_grad());
Variable y = torch::randn({2, 2}, torch::requires_grad());
auto res = simple_fn(x, y);
backward({res}, {torch::ones({2, 2})}, {}, true);
backward({res}, {torch::ones({2, 2})});
ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
}
TEST(AutogradAPITests, GradSimpleTest) {
// basic grad
Variable x = torch::randn({2, 2}, torch::requires_grad());
Variable y = torch::randn({2, 2}, torch::requires_grad());
auto res = simple_fn(x, y);
auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})});
ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2}));
ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2);
}
TEST(AutogradAPITests, GradTest) {
Variable x = torch::randn({2, 2}, torch::requires_grad());
Variable y = torch::randn({2, 2}, torch::requires_grad());
auto res = simple_fn(x, y);
res.backward(torch::ones({2, 2}), false, true);
Variable x_grad = y + torch::ones({2, 2});
Variable y_grad = x + torch::ones({2, 2}) * 2;
ASSERT_VARIABLE_EQ(x.grad(), x_grad);
ASSERT_VARIABLE_EQ(y.grad(), y_grad);
Variable grad_sum = 2 * x.grad() + y.grad();
auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);
ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
ASSERT_VARIABLE_EQ(x.grad(), x_grad);
ASSERT_VARIABLE_EQ(y.grad(), y_grad);
}
TEST(AutogradAPITests, GradNonLeafTest) {
Variable x_init = torch::randn({2, 2}, torch::requires_grad());
Variable x = x_init;
Variable y = torch::randn({2, 2}, torch::requires_grad());
Variable grad_output = torch::ones({2, 2});
for (int i = 0; i < 5; ++i) {
auto res = simple_fn(x, y);
auto input_grads = grad({res}, {x}, {grad_output}, {}, true);
Variable grad_x_expected = y + torch::ones({2, 2});
ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
ASSERT_FALSE(x.grad().defined());
ASSERT_FALSE(y.grad().defined());
x = x + 0.05 * input_grads[0];
}
float val_init = simple_fn(x_init, y).sum().item().toFloat();
float val_final = simple_fn(x, y).sum().item().toFloat();
ASSERT_TRUE(val_final > val_init);
x.backward(grad_output, false, true);
ASSERT_TRUE(x_init.grad().defined());
ASSERT_TRUE(y.grad().defined());
}
TEST(AutogradAPITests, GradUnreachableTest) {
Variable x = torch::ones({1}, torch::requires_grad());
Variable y = torch::ones({1}, torch::requires_grad());
Variable z = x * 2;
Variable w = y * 2;
auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
ASSERT_FALSE(grad_res[1].defined());
// This is slightly different than the case above, because z doesn't even
// have a grad accumulator allocated.
z = torch::ones({1}, torch::requires_grad());
grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);
ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
ASSERT_FALSE(grad_res[1].defined());
// allow_unused=False, but grads contains None inside, should throw
ASSERT_THROWS_WITH(
grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
}
TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
// Test that certain nodes are not erroneously executed when an input
// is unreachable. See #39784
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var) {
return var;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
ADD_FAILURE() << "This node should not be executed!";
return grad_output;
}
};
auto x = torch::randn(1, torch::requires_grad());
auto x1 = torch::randn(1);
auto x2 = MyFunction::apply(x + x1);
auto y = torch::randn(1, torch::requires_grad());
auto grad_res = torch::autograd::grad({x2}, {y}, {}, {}, false, true);
ASSERT_FALSE(grad_res[0].defined());
}
TEST(AutogradAPITests, EmptyInput) {
Variable x = torch::ones({1}, torch::requires_grad());
ASSERT_THROWS_WITH(
grad({x * 2}, /*inputs=*/{}, {x}), "grad requires non-empty inputs.");
}
TEST(AutogradAPITests, RetainGrad) {
auto input = torch::rand({1, 3}, torch::requires_grad());
auto h1 = input * 3;
auto out = (h1 * h1).sum();
{
// Warning when grad is accessed for non-leaf tensor
WarningCapture warnings;
ASSERT_FALSE(h1.grad().defined());
ASSERT_TRUE(warnings.str().find("is not a leaf") != std::string::npos);
}
// It should be possible to call retain_grad() multiple times
h1.retain_grad();
h1.retain_grad();
{
// If retain_grad is true for a non-leaf tensor,
// there should not be any warning when grad is accessed
WarningCapture warnings;
ASSERT_FALSE(h1.grad().defined());
ASSERT_FALSE(warnings.str().find("is not a leaf") != std::string::npos);
}
// Gradient should be accumulated
// NOLINTNEXTLINE(bugprone-argument-comment)
out.backward({}, /*keep_graph=*/true);
ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
// NOLINTNEXTLINE(bugprone-argument-comment)
out.backward({}, /*keep_graph=*/true);
ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());
{
torch::NoGradGuard no_grad;
input.grad().zero_();
}
// It should be a no-op for leaves
input.retain_grad();
input.retain_grad();
out.backward();
ASSERT_VARIABLE_EQ(input * 18, input.grad());
}
TEST(AutogradAPITests, AnomalyMode) {
// Needs to have backtrace as warning and then throw an error
torch::autograd::DetectAnomalyGuard detect_anomaly;
{
WarningCapture warnings;
auto x = torch::tensor({5.0}, torch::requires_grad());
auto y = x * x;
auto z = y * y;
y += 1;
ASSERT_THROWS_WITH(z.backward(), "inplace");
ASSERT_TRUE(
warnings.str().find("Traceback of forward") != std::string::npos);
}
{
WarningCapture warnings;
// Double backward
auto x = torch::tensor({0.0}, torch::requires_grad());
auto y = x.pow(1.5);
auto gr =
// NOLINTNEXTLINE(bugprone-argument-comment)
grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})});
, "returned nan");
auto msgs = warnings.messages();
ASSERT_EQ(msgs.size(), 2);
ASSERT_TRUE(
msgs[0].find("Traceback of forward call that caused the error") !=
std::string::npos);
ASSERT_TRUE(
msgs[1].find(
"Traceback of forward call that induced the previous calculation") !=
std::string::npos);
}
}
TEST(CustomAutogradTest, CustomFunction) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(
AutogradContext* ctx,
Variable var1,
int mul,
Variable var2) {
ctx->saved_data["mul"] = mul;
ctx->save_for_backward({var1, var2});
return var1 + mul * var2 + var1 * var2;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
int mul = ctx->saved_data["mul"].toInt();
auto saved = ctx->get_saved_variables();
auto var1 = saved[0];
auto var2 = saved[1];
variable_list output = {
grad_output[0] + grad_output[0] * var2,
Variable(),
grad_output[0] * mul + grad_output[0] * var1};
return output;
}
};
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
auto res = MyFunction::apply(x, 2, y);
auto go = torch::ones({}, torch::requires_grad());
res.sum().backward(go, false, true);
ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}
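// Covers the graph-task edge trimming from this commit (#82544): inside
// backward, ctx->needs_input_grad() reflects which inputs the current
// torch::autograd::grad() call actually asked for, so unneeded gradient
// computations can be skipped.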
TEST(CustomAutogradTest, GraphTaskTrimEdges) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(
AutogradContext* ctx,
Variable var1,
Variable var2,
int mul,
bool needs_input1_grad,
bool needs_input2_grad) {
// setup the expected should and should not compute idx
ctx->saved_data["needs_input1_grad"] = needs_input1_grad;
ctx->saved_data["needs_input2_grad"] = needs_input2_grad;
ctx->saved_data["mul"] = mul;
ctx->save_for_backward({var1, var2});
return var1 + mul * var2 + var1 * var2;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Test `needs_input_grad` method is working correctly.
// We have to test this within the backward function.
auto needs_input1_grad = ctx->saved_data["needs_input1_grad"].toBool();
auto needs_input2_grad = ctx->saved_data["needs_input2_grad"].toBool();
IndexRange var1_idx = {0, 1};
IndexRange var2_idx = {1, 2};
EXPECT_EQ(ctx->needs_input_grad(0), needs_input1_grad);
EXPECT_EQ(ctx->needs_input_grad(1), needs_input2_grad);
EXPECT_EQ(ctx->needs_input_grad({var1_idx}), needs_input1_grad);
EXPECT_EQ(ctx->needs_input_grad({var2_idx}), needs_input2_grad);
EXPECT_EQ(
ctx->needs_input_grad({var1_idx, var2_idx}),
needs_input1_grad || needs_input2_grad);
// calculate gradients
int mul = ctx->saved_data["mul"].toInt();
auto saved = ctx->get_saved_variables();
auto var1 = saved[0];
auto var2 = saved[1];
Variable grad_var1, grad_var2;
if (ctx->needs_input_grad(0)) {
grad_var1 = grad_output[0] + grad_output[0] * var2;
}
if (ctx->needs_input_grad(1)) {
grad_var2 = grad_output[0] * mul + grad_output[0] * var1;
}
variable_list output = {
grad_var1,
grad_var2,
Variable(),
Variable(),
Variable(),
};
return output;
}
};
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
auto go = torch::ones_like(x);
Variable out;
// grad_x
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ true,
/* needs_input2_grad= */ false);
auto grad_x = torch::autograd::grad({out}, {x}, {go})[0];
ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));
// grad_y
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ false,
/* needs_input2_grad= */ true);
auto grad_y = torch::autograd::grad({out}, {y}, {go})[0];
ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);
// grad_x and grad_y
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ true,
/* needs_input2_grad= */ true);
auto grads = torch::autograd::grad({out}, {x, y}, {go});
ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5}));
ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2);
}
TEST(CustomAutogradTest, FunctionReturnsInput) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var1) {
return var1;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {grad_output[0] * 2};
}
};
Variable x(torch::ones(1, torch::requires_grad()));
MyFunction::apply(x).backward(torch::ones(1), true, true);
ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
}
TEST(CustomAutogradTest, FunctionReturnsUndefined) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var) {
return var * 2;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
at::Tensor undefined_tensor;
return {undefined_tensor};
}
};
auto x = torch::ones(1, torch::requires_grad());
MyFunction::apply(x).backward();
ASSERT_FALSE(x.grad().defined());
MyFunction::apply(x.pow(2)).backward();
ASSERT_FALSE(x.grad().defined());
MyFunction::apply(x).sum().backward();
ASSERT_FALSE(x.grad().defined());
ASSERT_FALSE(torch::autograd::grad(
{MyFunction::apply(x)}, {x}, {}, false, false, true)[0]
.defined());
}
TEST(CustomAutogradTest, MaterializeGrads) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var) {
return var;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1));
return grad_output;
}
};
auto x = torch::ones(1, torch::requires_grad());
UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
}
TEST(CustomAutogradTest, DontMaterializeGrads) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var) {
ctx->set_materialize_grads(false);
return var;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
EXPECT_FALSE(grad_output[0].defined());
return grad_output;
}
};
auto x = torch::ones(1, torch::requires_grad());
UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
}
TEST(CustomAutogradTest, NoGradCustomFunction) {
// Custom Function should respect grad mode
struct MyOp : public Function<MyOp> {
static Variable forward(AutogradContext* ctx, Variable x) {
return x + 1;
}
static variable_list backward(AutogradContext* ctx, variable_list dy) {
return dy;
}
};
auto x = torch::ones({5, 5}, torch::requires_grad());
{
at::NoGradGuard no_grad;
auto y = MyOp::apply(x);
ASSERT_FALSE(y.requires_grad());
}
}
TEST(CustomAutogradTest, MarkDirty) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable v) {
// Change the value inplace
auto v_data = v.data_ptr<float>();
v_data[0] = 2;
ctx->mark_dirty({v});
return v;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {(grad_output[0] * 2.0)};
}
};
// Clone here because modifying leafs inplace is not allowed
auto x = torch::randn({5, 5}, torch::requires_grad()).clone();
auto version_before = x._version();
auto out = MyFunction::apply(x);
auto version_after = x._version();
ASSERT_TRUE(version_after >= (version_before + 1));
out.sum().backward();
}
TEST(CustomAutogradTest, MarkNonDifferentiable) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable v) {
Variable output = v > 0;
ctx->mark_non_differentiable({output});
return output;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {(grad_output[0] * 0.0)};
}
};
auto x = torch::randn({5, 5}, torch::requires_grad());
auto mask = MyFunction::apply(x);
ASSERT_FALSE(mask.requires_grad());
auto y = x.masked_fill(mask, 0);
y.sum().backward();
}
TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
struct MyFunction : public Function<MyFunction> {
static variable_list forward(AutogradContext* ctx, Variable input) {
Variable a = input + 1;
Variable b = input + 2;
ctx->mark_non_differentiable({a});
return {a, b};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
return {grad_b};
}
};
auto x = torch::randn({5, 5}, torch::requires_grad());
auto out = MyFunction::apply(x);
ASSERT_FALSE(out[0].requires_grad());
ASSERT_TRUE(out[1].requires_grad());
out[1].sum().backward();
ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
}
TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable input) {
auto output = input.clone();
ctx->mark_non_differentiable({output});
return output;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_outputs) {
return {};
}
};
auto x = torch::randn({5, 5}, torch::requires_grad());
auto r = MyFunction::apply(x * x);
(r * x).sum().backward();
}
TEST(CustomAutogradTest, ReturnLeafInplace) {
struct Inplace : public Function<Inplace> {
static variable_list forward(AutogradContext* ctx, Variable a, Variable b) {
ctx->mark_dirty({a});
return {a.add_(b), b + 2};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {grad_output[0], grad_output[0] + grad_output[1]};
}
};
Variable x = torch::randn({5, 5});
Variable y = torch::randn({5, 5}, torch::requires_grad());
auto out = Inplace::apply(x, y);
auto& q = out[0];
ASSERT_TRUE(torch::equal(q, x));
ASSERT_TRUE(q.requires_grad());
q.sum().backward();
ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
}
TEST(CustomAutogradTest, ReturnDuplicateInplace) {
struct DoubleInplace : public Function<DoubleInplace> {
static variable_list forward(AutogradContext* ctx, Variable x) {
x.mul_(2);
ctx->mark_dirty({x});
return {x, x};
}
static variable_list backward(
AutogradContext* ctsx,
variable_list grad_outputs) {
return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
}
};
auto x = torch::randn({5, 5}, torch::requires_grad());
ASSERT_THROWS_WITH(
DoubleInplace::apply(x), "leaf Variable that requires grad");
// TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one
// output");
auto out = DoubleInplace::apply(x.clone());
ASSERT_TRUE(torch::equal(out[0], out[1]));
}
TEST(CustomAutogradTest, ReturnDuplicate) {
struct DoubleDuplicate : public Function<DoubleDuplicate> {
static variable_list forward(AutogradContext* ctx, Variable x) {
auto output = x * 2;
return {output, output};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_outputs) {
return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
}
};
auto x = torch::randn({5, 5}, torch::requires_grad());
auto out = DoubleDuplicate::apply(x);
ASSERT_TRUE(torch::equal(out[0], out[1]));
}
TEST(CustomAutogradTest, SaveEmptyForBackward) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable input) {
ctx->save_for_backward({Variable(), input, Variable()});
return input * input;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
auto saved = ctx->get_saved_variables();
EXPECT_FALSE(saved[0].defined());
EXPECT_FALSE(saved[2].defined());
return {saved[1] * 2 * grad_output[0]};
}
};
Variable x = torch::randn({5, 5}, torch::requires_grad());
auto y = MyFunction::apply(x);
y.sum().backward();
ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
}
TEST(CustomAutogradTest, InvalidGradients) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable x) {
return x * 2;
}
static variable_list backward(
AutogradContext* ctsx,
variable_list grad_outputs) {
return {
torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
}
};
auto input1 =
torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
ASSERT_THROWS_WITH(
MyFunction::apply(input1).sum().backward(), "expected shape");
auto input2 =
torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
}
TEST(CustomAutogradTest, NoGradInput) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext*, Variable x) {
return x;
}
static variable_list backward(
AutogradContext*,
variable_list grad_outputs) {
return grad_outputs;
}
};
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y;
{
at::NoGradGuard no_grad;
y = MyFunction::apply(x);
}
ASSERT_TRUE(x.requires_grad());
ASSERT_FALSE(y.grad_fn());
}
TEST(CustomAutogradTest, TooManyGrads) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext*, Variable input) {
return input;
}
static variable_list backward(AutogradContext*, variable_list grad_output) {
grad_output.insert(grad_output.end(), {Variable(), Variable()});
return grad_output;
}
};
}
TEST(CustomAutogradTest, DepNoGrad) {
struct F1 : public Function<F1> {
static variable_list forward(AutogradContext* ctx, Variable input) {
auto out = torch::randn(input.sizes());
ctx->mark_non_differentiable({out});
return {input, out};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {grad_output[0]};
}
};
struct F2 : public Function<F2> {
static Variable forward(AutogradContext*, Variable input, Variable ignore) {
return input;
}
static variable_list backward(AutogradContext*, variable_list grad_output) {
return {grad_output[0], Variable()};
}
};
auto x = torch::randn(5, torch::requires_grad());
auto out = F1::apply(x);
Variable &a = out[0], &b = out[1];
b = b + 1; // Separate F1 and F2 by another operation
ASSERT_TRUE(a.requires_grad());
ASSERT_FALSE(b.requires_grad());
auto c = F2::apply(a, b);
c.backward(torch::ones(c.sizes()), false, false);
ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}
TEST(CustomAutogradTest, Reentrant) {
static Variable y_data = torch::randn({2, 2});
struct Reenter : public Function<Reenter> {
static Variable forward(AutogradContext* ctx, Variable input) {
Variable output;
{
at::AutoGradMode enable_grad(true);
auto x = make_variable(input.tensor_data(), true);
auto y = make_variable(y_data.tensor_data(), true);
output = x * y;
ctx->saved_data["x"] = x;
ctx->saved_data["y"] = y;
ctx->saved_data["output_var"] = output;
}
return output.detach();
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
{
at::AutoGradMode enable_grad(true);
auto out = ctx->saved_data["output_var"].toTensor();
out.sum().backward();
}
return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
}
};
auto x = torch::randn({2, 2}, torch::requires_grad());
auto out = Reenter::apply(x);
out.sum().backward();
ASSERT_VARIABLE_EQ(x.grad(), y_data);
}
// NOTE: If this fails for apparently unrelated reasons in TSAN be aware of
// the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950
TEST(CustomAutogradTest, DeepReentrant) {
struct DeepReenter : public Function<DeepReenter> {
static Variable forward(AutogradContext* ctx, Variable x) {
{
at::AutoGradMode enable_grad(true);
ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
}
return ctx->saved_data["x"].toTensor().detach();
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
return grad_output;
}
{
at::AutoGradMode enable_grad(true);
apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
return grad_output;
}
}
};
// This should not stack overflow
auto v =
torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
DeepReenter::apply(v).sum().backward();
}
TEST(CustomAutogradTest, ReentrantPriority) {
static std::vector<int> order;
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext*, Variable x) {
return x;
}
static variable_list backward(AutogradContext*, variable_list grad) {
order.push_back(0);
return grad;
}
};
struct Reenter : public Function<Reenter> {
static Variable forward(AutogradContext* ctx, Variable x) {
{
at::AutoGradMode enable_grad(true);
ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
}
return ctx->saved_data["x"].toTensor().detach();
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
order.push_back(1);
if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
return grad_output;
}
{
at::AutoGradMode enable_grad(true);
apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
return grad_output;
}
}
};
auto a = MyFunction::apply(
torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
auto b = Reenter::apply(
torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
auto v = a * b;
v.backward();
// All the reentrant tasks should be prioritized over the MyFunction backward
// task.
ASSERT_EQ(order.size(), 10);
ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
ASSERT_EQ(order.back(), 0);
// Clear static variable in case test get executed in a loop
order.clear();
}
TEST(CustomAutogradTest, Hooks) {
Variable x = torch::ones({5, 5}, torch::requires_grad());
Variable y = torch::ones({5, 5}) * 4;
y.set_requires_grad(true);
int counter = 0;
std::function<void(int, Variable)> bw_hook(
[&counter](int inc, Variable grad) { counter += inc; });
Variable z = x * x + x * 2 + x * y + y;
x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); });
auto hook_1 =
z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); });
z.backward(torch::ones({5, 5}), true, true);
ASSERT_EQ(counter, 1);
auto hook_2 =
z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); });
z.backward(torch::ones({5, 5}), true, true);
ASSERT_EQ(counter, 4);
z.remove_hook(hook_2);
z.backward(torch::ones({5, 5}), true, true);
ASSERT_EQ(counter, 5);
std::function<Variable(Variable)> bw_hook_modify(
[](Variable grad) { return grad.mul(2); });
z.remove_hook(hook_1);
z.register_hook(bw_hook_modify);
y.grad().zero_();
z.backward(torch::ones({5, 5}), true, false);
ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);
y.register_hook(bw_hook_modify);
y.grad().zero_();
z.backward(torch::ones({5, 5}), false, false);
ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);
ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}
TEST(CustomAutogradTest, HooksInplace) {
auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
int hook1_count = 0;
auto hook1 = ([&hook1_count](Variable grad) {
hook1_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
});
int hook2_count = 0;
auto hook2 = ([&hook2_count](Variable grad) {
hook2_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
});
a.register_hook(hook1);
a.mul_(2);
a.register_hook(hook2);
auto out = (a + 1).sum();
out.backward();
ASSERT_EQ(hook1_count, 1);
ASSERT_EQ(hook2_count, 1);
}
TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
int hook1_count = 0;
auto hook1 = ([&hook1_count](Variable grad) {
hook1_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
});
int hook2_count = 0;
auto hook2 = ([&hook2_count](Variable grad) {
hook2_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
});
int hook3_count = 0;
auto hook3 = ([&hook3_count](Variable grad) {
hook3_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
});
a.register_hook(hook1);
a.retain_grad();
a.register_hook(hook2);
a.mul_(2);
a.register_hook(hook3);
auto out = (a + 1).sum();
out.backward();
ASSERT_EQ(hook1_count, 1);
ASSERT_EQ(hook2_count, 1);
ASSERT_EQ(hook3_count, 1);
ASSERT_TRUE(a.retains_grad());
ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}
TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
int hook1_count = 0;
auto hook1 = ([&hook1_count](Variable grad) {
hook1_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
});
int hook2_count = 0;
auto hook2 = ([&hook2_count](Variable grad) {
hook2_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
});
int hook3_count = 0;
auto hook3 = ([&hook3_count](Variable grad) {
hook3_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
});
a.register_hook(hook1);
a.retain_grad();
a.register_hook(hook2);
a.mul_(2);
a.mul_(2);
a.register_hook(hook3);
auto out = (a + 1).sum();
out.backward();
ASSERT_EQ(hook1_count, 1);
ASSERT_EQ(hook2_count, 1);
ASSERT_EQ(hook3_count, 1);
ASSERT_TRUE(a.retains_grad());
ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}
TEST(CustomAutogradTest, HookNone) {
struct NoneGradientFunction : public Function<NoneGradientFunction> {
static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
return {x, y};
}
static variable_list backward(AutogradContext* ctx, variable_list grad) {
return {grad[0], Variable()};
}
};
bool was_called = false;
auto hook = ([&was_called](Variable grad) {
ASSERT_TRUE(grad.defined());
was_called = true;
});
auto x = torch::randn({5, 5}, torch::requires_grad());
auto y = torch::randn({5, 5});
auto out = NoneGradientFunction::apply(x, y);
Variable rx = x[0], ry = x[1];
rx.register_hook(hook);
ry.register_hook(hook);
(rx + ry).sum().backward();
ASSERT_TRUE(was_called);
}
TEST(CustomAutogradTest, BackwardWithInputs) {
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
Variable z = x * x + x * y + y * y;
Variable x_grad_expected = 2 * x + y;
Variable y_grad_expected = x + 2 * y;
z.backward(torch::ones({5, 5}), false, false, {x});
ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
ASSERT_FALSE(y.grad().defined());
}
TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
Variable z = x * x + x * y + y * y;
Variable x_grad_expected = 2 * x + y;
Variable y_grad_expected = x + 2 * y;
ASSERT_THROWS_WITH(
z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}),
"cannot be empty");
}
TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
Variable z = x * x;
Variable w = y * z + x * y + y * y;
Variable x_grad_expected = 2 * x * y + y;
Variable z_grad_expected = y;
w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z});
ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
ASSERT_VARIABLE_EQ(z.grad(), z_grad_expected);
ASSERT_FALSE(y.grad().defined());
}
TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
c10::Warning::WarnAlways guard(true);
torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true);
auto z = x * x;
{
WarningCapture warnings;
z.backward(torch::ones({5, 5}), c10::nullopt, true);
ASSERT_TRUE(
warnings.str().find("Using backward() with create_graph=True") !=
std::string::npos);
}
{
WarningCapture warnings;
torch::autograd::backward({z}, {torch::ones({5, 5})}, c10::nullopt, true);
ASSERT_TRUE(
warnings.str().find("Using backward() with create_graph=True") !=
std::string::npos);
}
}
/**
* Tests for AutogradNotImplementedFallback
* - Check that we created the NotImplemented kernel when inputs require grad
* but when no inputs require grad, we should not create this node
* - check_inplace logic
* - view ops
* - TODO: Tests for debug-only checks? Don't need for now because CI doesn't
* test non-NDEBUG builds.
* - tensorlist input and output
* - multiple outputs / non-tensor output
* - rebase_history vs set_history
*/
namespace {
torch::Tensor inplace_op(
const torch::Tensor& self,
const torch::Tensor& other) {
return self.add_(other);
}
std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(
const torch::Tensor& self,
const torch::Tensor& other) {
other.add_(self);
self.add_(other);
return std::tuple<torch::Tensor, torch::Tensor>(self, other);
}
std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(
const torch::Tensor& self,
const torch::Tensor& other) {
// This is not allowed. We test below that this calling into the boxed kernel
// will raise an error
return std::tuple<torch::Tensor, torch::Tensor>(self, other);
}
std::tuple<torch::Tensor, torch::Tensor> non_first_view_op(
const torch::Tensor& self,
const torch::Tensor& other) {
// This is not allowed. We test below that this calling into the boxed kernel
// will raise an error
return std::tuple<torch::Tensor, torch::Tensor>(self.clone(), other);
}
int64_t ret_single_non_tensor(
const torch::Tensor& self,
const torch::Tensor& other) {
return 12;
}
torch::Tensor opt_op(
const torch::Tensor& self,
const c10::optional<at::Tensor>& other) {
if (other.has_value()) {
return self + other.value();
} else {
return self.clone();
}
}
torch::Tensor my_custom_op(
const torch::Tensor& self,
const torch::Tensor& other) {
return self + other;
}
std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
const torch::Tensor& self,
const torch::Tensor& other) {
auto a = self - other;
auto b = self + other;
return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12);
}
torch::Tensor view_op(const torch::Tensor& self) {
return self;
}
torch::Tensor view_op_with_extra_arg(
const torch::Tensor& self,
const torch::Tensor& other) {
return self;
}
std::vector<torch::Tensor> ret_tensor_vector_view(
const torch::Tensor& self,
const torch::Tensor& other) {
return {self, self};
}
std::vector<at::Tensor> ret_tensor_vector(
const torch::Tensor& self,
const torch::Tensor& other) {
std::vector<at::Tensor> out;
out.push_back(self + other);
out.push_back(self - other);
return out;
}
torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
const auto& res = self.clone();
for (const auto& t : other) {
res.add_(t);
}
return res;
}
#define REGISTER_TEST_OP(name, schema, fn) \
auto m = MAKE_TORCH_LIBRARY(_test); \
m.def(schema); \
auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd); \
auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); \
auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView); \
m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn)); \
m_autograd.impl( \
name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); \
m_inplaceorview.impl( \
name, \
c10::DispatchKey::ADInplaceOrView, \
autogradNotImplementedInplaceOrViewFallback());
template <typename F>
void assertBasicChecks(F op) {
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
auto c = torch::tensor({1.}, {torch::kFloat32});
// If any inputs require grad,
auto out1 = op(a, b);
ASSERT_THROWS_WITH(out1.backward(), "is not implemented");
// # Should not have grad_fn if none require grad
auto out2 = op(b, c);
ASSERT_THROWS_WITH(
out2.backward(),
"element 0 of tensors does not require grad and does not have a grad_fn");
// TODO: Forward AD Tests?
}
} // namespace
// These tests trigger an MSVC bug in the internal arvr build
// Reproduce with: buck build @arvr/mode/win/opt
// //xplat/caffe2:autograd_libtorch_test_ovrsource It is probably caused by the
// lambda, see https://github.com/pytorch/pytorch/issues/48763
#if !defined(_MSC_VER)
TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
REGISTER_TEST_OP(
"ret_single_non_tensor",
"_test::ret_single_non_tensor(Tensor self, Tensor other) -> int",
ret_single_non_tensor);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::ret_single_non_tensor", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(
opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b));
}
TEST(TestAutogradNotImplementedFallback, InplaceOp) {
REGISTER_TEST_OP(
"inplace_op",
"_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)",
inplace_op);
auto opHandle =
c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
torch::Tensor,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
// Check in-place
ASSERT_THROWS_WITH(
op(a, b),
"a leaf Variable that requires grad is being used in an in-place operation");
op(b, a);
a = a.clone();
b = b.clone();
auto c = op(a, b);
ASSERT_TRUE(torch::allclose(c, inplace_op(a, b)));
// Test in-place on view
auto base =
torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
auto view = base.view(-1);
auto t = torch::tensor({1.}, {torch::kFloat32});
torch::Tensor v_nograd;
{
c10::NoGradGuard guard;
v_nograd = base.view(-1);
op(v_nograd, t);
}
ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
ASSERT_THAT(
op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
}
TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
REGISTER_TEST_OP(
"two_arg_inplace_op",
"_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))",
two_arg_inplace_op);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::two_arg_inplace_op", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
std::tuple<torch::Tensor, torch::Tensor>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
// Both are modified in-place!
ASSERT_THROWS_WITH(
op(a, b),
"a leaf Variable that requires grad is being used in an in-place operation");
ASSERT_THROWS_WITH(
op(b, a),
"a leaf Variable that requires grad is being used in an in-place operation");
auto c =
torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
auto d =
torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
auto saved_version_c = c._version();
auto saved_version_d = d._version();
op(c, d);
ASSERT_NE(c._version(), saved_version_c);
ASSERT_NE(d._version(), saved_version_d);
}
TEST(TestAutogradNotImplementedFallback, OptOp) {
REGISTER_TEST_OP(
"opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
auto opHandle =
c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
auto op = [&](const torch::Tensor& _1,
const c10::optional<torch::Tensor>& _2) {
return callOpUnboxed<
torch::Tensor,
const torch::Tensor&,
const c10::optional<torch::Tensor>&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b)));
ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {})));
}
TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
REGISTER_TEST_OP(
"my_custom_op",
"_test::my_custom_op(Tensor self, Tensor other) -> Tensor",
my_custom_op);
auto opHandle =
c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
torch::Tensor,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
assertBasicChecks(op);
}
TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
REGISTER_TEST_OP(
"ret_tuple_non_tensor",
"_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)",
ret_tuple_non_tensor);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::ret_tuple_non_tensor", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
torch::Tensor out0;
torch::Tensor out1;
int64_t out2;
auto out = callOpUnboxed<
std::tuple<torch::Tensor, torch::Tensor, int64_t>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
std::tie(out0, out1, out2) = std::move(out);
return out0;
};
assertBasicChecks(op);
}
TEST(TestAutogradNotImplementedFallback, ViewOp) {
REGISTER_TEST_OP(
"view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op);
auto opHandle =
c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
auto op = [&](const torch::Tensor& _1) {
return callOpUnboxed<torch::Tensor, const torch::Tensor&>(opHandle, _1);
};
auto b = torch::tensor({1.}, {torch::kFloat32});
auto v = op(b);
ASSERT_TRUE(v.is_view());
ASSERT_EQ(v._base().unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
auto b1 =
torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
auto v1 = op(b1);
ASSERT_TRUE(v1.is_view());
ASSERT_EQ(v1._base().unsafeGetTensorImpl(), b1.unsafeGetTensorImpl());
// Test inplace on view
auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
// raise on rebase_history when it refreshes grad_fn
ASSERT_THROWS_WITH(
v1.add_(t), "which does not have a derivative implemented is forbidden");
// base should not be aware of the views, so this is still okay
b1.add_(t);
ASSERT_THROWS_WITH(
v1.grad_fn(),
"which does not have a derivative implemented is forbidden");
}
TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
REGISTER_TEST_OP(
"view_op_with_extra_arg",
"_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
view_op_with_extra_arg);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::view_op_with_extra_arg", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
torch::Tensor,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
assertBasicChecks(op);
auto a = torch::tensor({1.}, {torch::kFloat32});
auto b = torch::tensor({2.}, {torch::kFloat32});
auto out1 = op(a, b);
ASSERT_TRUE(out1.is_view());
ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
}
TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
REGISTER_TEST_OP(
"ret_tensor_vector_view",
"_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)",
ret_tensor_vector_view);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::ret_tensor_vector_view", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
std::vector<at::Tensor>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32});
auto b = torch::tensor({1.}, {torch::kFloat32});
auto out = op(a, b);
ASSERT_TRUE(out[0].is_view());
ASSERT_EQ(out[0]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
ASSERT_TRUE(out[1].is_view());
ASSERT_EQ(out[1]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
}
TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
REGISTER_TEST_OP(
"two_pairs_of_view_op",
"_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))",
two_pairs_of_view_op);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::two_pairs_of_view_op", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
std::tuple<torch::Tensor, torch::Tensor>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
ASSERT_THROWS_WITH(
op(a, b),
"Expected only a single output in the operator schema to have a non-write alias annotation");
}
TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) {
REGISTER_TEST_OP(
"non_first_view_op",
"_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))",
non_first_view_op);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::non_first_view_op", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
std::tuple<torch::Tensor, torch::Tensor>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
ASSERT_THROWS_WITH(
op(a, b), "can only create view relationships between the first");
}
TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
REGISTER_TEST_OP(
"ret_tensor_vector",
"_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]",
ret_tensor_vector);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::ret_tensor_vector", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
std::vector<at::Tensor>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2)[0];
};
assertBasicChecks(op);
}
TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
REGISTER_TEST_OP(
"tensorlist_op",
"_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor",
tensorlist_op);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::tensorlist_op", "");
auto op = [&](torch::Tensor _1, at::TensorList _2) {
return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(
opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32});
auto b = torch::tensor({1.}, {torch::kFloat32});
auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
std::vector<torch::Tensor> vec = {b, c};
auto out = op(a, vec);
ASSERT_THROWS_WITH(
torch::autograd::grad({out}, {vec[0]}),
"One of the differentiated Tensors does not require grad");
ASSERT_THROWS_WITH(
torch::autograd::grad({out}, {vec[1]}), "is not implemented");
ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
}
#endif
// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
// test_save_output_nr
// test_free_deep_graph_pyfunction
// test_naughty_anomaly_access
// test_naughty_autograd-function_stashing_ctx
// test_custom_autograd_repeated_grad_grad
// test_return_leaf
// test_anomaly_detect_nan
// test_no_grad_copy