Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-05 16:44:58 +08:00)
Enable AutoGradMode in InferenceMode. (#56107)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56107

Test Plan: Imported from OSS

Reviewed By: pbelevich, driazati

Differential Revision: D27807137

Pulled By: ailzhang

fbshipit-source-id: bfacf11ec5a431589cec73d6371cac81b425a115
Committed by: Facebook GitHub Bot
Parent: 8881f504f1
Commit: 98162cb0bb
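Note: the practical effect of this change is that entering InferenceMode now also flips grad mode off for the duration of the guard, which is exactly what the new GradMode::is_enabled() assertion in the diff below checks. A minimal standalone sketch of that invariant (assuming a libtorch build with the c10::InferenceMode and c10::GradMode headers; illustrative only, not part of this commit):

    #include <c10/core/GradMode.h>
    #include <c10/core/InferenceMode.h>
    #include <iostream>

    int main() {
      // Grad mode is on by default outside any guard.
      std::cout << c10::GradMode::is_enabled() << "\n";  // 1
      {
        c10::InferenceMode guard;
        // Inside the guard, grad mode is disabled as well,
        // mirroring the test's GradMode::is_enabled() assert.
        std::cout << c10::GradMode::is_enabled() << "\n";  // 0
      }
      // Restored on scope exit (RAII).
      std::cout << c10::GradMode::is_enabled() << "\n";  // 1
    }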
@@ -40,6 +40,7 @@ namespace {
   ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(c10::autograd_dispatch_keyset));
   ASSERT_EQ(c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), inference_mode);
   ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::InplaceOrView), !inference_mode);
+  ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
 }
 }

@@ -534,3 +535,34 @@ TEST(InferenceModeTest, TestComplexViewInNormalMode) {
   tmp = torch::view_as_real(tmp);
   ASSERT_TRUE(is_inference_tensor(tmp));
 }
+
+TEST(InferenceModeTest, TestCustomFunction) {
+  struct MyFunction : public Function<MyFunction> {
+    static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
+      ctx->saved_data["mul"] = mul;
+      ctx->save_for_backward({var1, var2});
+      return var1 + mul*var2 + var1*var2;
+    }
+
+    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
+      int mul = ctx->saved_data["mul"].toInt();
+      auto saved = ctx->get_saved_variables();
+      auto var1 = saved[0];
+      auto var2 = saved[1];
+      variable_list output = {grad_output[0] + grad_output[0]*var2, Variable(), grad_output[0] * mul + grad_output[0] * var1};
+      return output;
+    }
+  };
+
+  {
+    InferenceMode guard;
+    torch::Tensor var1 = torch::ones({3, 3}).set_requires_grad(true);
+    auto var2 = var1.clone();
+    int mul = 2;
+    // If InferenceMode didn't set NoGradGuard automatically, this line
+    // would error out when trying to save `var1` and `var2` for backward.
+    auto y = MyFunction::apply(var1, mul, var2);
+    torch::Tensor expected = var1 + mul * var2 + var1 * var2;
+    assert_tensor_equal(y, expected);
+  }
+}
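For context, the MyFunction defined in the new test is an ordinary custom autograd function; outside InferenceMode it participates in backward as usual. A minimal sketch of that usage (an assumed standalone program, not part of the diff) — note that backward() returns one gradient slot per forward() argument, with an undefined Variable() for the non-tensor mul:

    #include <torch/torch.h>
    #include <iostream>

    using torch::autograd::AutogradContext;
    using torch::autograd::Function;
    using torch::autograd::Variable;
    using torch::autograd::variable_list;

    // Same MyFunction as in the test above.
    struct MyFunction : public Function<MyFunction> {
      static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
        ctx->saved_data["mul"] = mul;
        ctx->save_for_backward({var1, var2});
        return var1 + mul * var2 + var1 * var2;
      }

      static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
        int mul = ctx->saved_data["mul"].toInt();
        auto saved = ctx->get_saved_variables();
        auto var1 = saved[0];
        auto var2 = saved[1];
        // One gradient slot per forward() argument; the non-tensor
        // `mul` gets an undefined Variable().
        return {grad_output[0] + grad_output[0] * var2,
                Variable(),
                grad_output[0] * mul + grad_output[0] * var1};
      }
    };

    int main() {
      auto var1 = torch::ones({3, 3}, torch::requires_grad());
      auto var2 = torch::ones({3, 3});
      auto y = MyFunction::apply(var1, 2, var2);
      y.sum().backward();
      // From y = var1 + mul*var2 + var1*var2: dy/dvar1 = 1 + var2,
      // so every entry of var1.grad() is 2 here.
      std::cout << var1.grad() << std::endl;
    }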