Enable AutoGradMode in InferenceMode. (#56107)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56107

Test Plan: Imported from OSS

Reviewed By: pbelevich, driazati

Differential Revision: D27807137

Pulled By: ailzhang

fbshipit-source-id: bfacf11ec5a431589cec73d6371cac81b425a115
This commit is contained in:
Ailing Zhang
2021-04-19 10:22:20 -07:00
committed by Facebook GitHub Bot
parent 8881f504f1
commit 98162cb0bb
2 changed files with 48 additions and 2 deletions

View File

@ -40,6 +40,7 @@ namespace {
ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(c10::autograd_dispatch_keyset));
ASSERT_EQ(c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), inference_mode);
ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::InplaceOrView), !inference_mode);
ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
}
}
@ -534,3 +535,34 @@ TEST(InferenceModeTest, TestComplexViewInNormalMode) {
tmp = torch::view_as_real(tmp);
ASSERT_TRUE(is_inference_tensor(tmp));
}
TEST(InferenceModeTest, TestCustomFunction) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
ctx->saved_data["mul"] = mul;
ctx->save_for_backward({var1, var2});
return var1 + mul*var2 + var1*var2;
}
static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
int mul = ctx->saved_data["mul"].toInt();
auto saved = ctx->get_saved_variables();
auto var1 = saved[0];
auto var2 = saved[1];
variable_list output = {grad_output[0] + grad_output[0]*var2, Variable(), grad_output[0] * mul + grad_output[0] * var1};
return output;
}
};
{
InferenceMode guard;
torch::Tensor var1 = torch::ones({3, 3}).set_requires_grad(true);
auto var2 = var1.clone();
int mul = 2;
// If InferenceMode didn't set NoGradGuard automatically, this line
// would error out when trying to save `var1` and `var2` for backward.
auto y = MyFunction::apply(var1, mul, var2);
torch::Tensor expected = var1 + mul * var2 + var1 * var2;
assert_tensor_equal(y, expected);
}
}