Fix retains grad behavior after in-place (#79996)

See this doc: https://docs.google.com/document/d/1KiRdnoj6B4cI3yl017hTbCqcOGO1gWIpUf20sldipHM/edit#

Two issues are fixed: (1) one regarding cpp hooks in general and (2) one regarding retains_grad hooks. Python hooks, which rely on a different mechanism, are not fixed here:
- Hooks in cpp in general
  - (fixed) new hooks registered to a newer version of the tensor no longer get applied to the grad_fn
    associated with the older version of the tensor on which the first hook was originally registered
  - (unchanged) hooks registered to the older version of the tensor remain active on the grad_fn
    associated with that older version
- Retains grad hooks
  - (fixed) now get moved to the latest grad_fn. NB: to the user, retains_grad is not considered a hook
    and is not expected to behave like one: hooks are properties of a particular grad_fn, whereas
    retains_grad-ness is a property of the tensor (see the Python sketch after this list).
- Python hooks (not in this PR)
  - (will fix) same issue as cpp hooks: new hooks are still applied to the grad_fn associated
    with the older version of the tensor
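
As an illustration of the user-visible retains_grad behavior, here is a minimal Python sketch mirroring `test_retain_grad_inplace` added in this PR (the printed value follows that test's assertion):

```python
import torch

a = torch.tensor([1.], requires_grad=True).clone()  # non-leaf, so retain_grad() is meaningful
a.retain_grad()
a.mul_(2)           # in-place op rebases a's grad_fn; the retains_grad hook must follow it
a.sum().backward()
print(a.grad)       # tensor([1.]) -- previously the hook could be left on the stale grad_fn
```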
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79996
Approved by: https://github.com/albanD
Author: soulitzer
Date: 2022-07-08 11:39:24 -04:00
Committed by: PyTorch MergeBot
Parent: e9b3bc2ead
Commit: 516f3198d6
4 changed files with 188 additions and 7 deletions

test/cpp/api/autograd.cpp

@@ -831,6 +831,111 @@ TEST(CustomAutogradTest, Hooks) {
   ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
 }
 
+TEST(CustomAutogradTest, HooksInplace) {
+  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
+
+  int hook1_count = 0;
+  auto hook1 = ([&hook1_count](Variable grad) {
+    hook1_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
+  });
+
+  int hook2_count = 0;
+  auto hook2 = ([&hook2_count](Variable grad) {
+    hook2_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
+  });
+
+  a.register_hook(hook1);
+  a.mul_(2);
+  a.register_hook(hook2);
+
+  auto out = (a + 1).sum();
+  out.backward();
+
+  ASSERT_EQ(hook1_count, 1);
+  ASSERT_EQ(hook2_count, 1);
+}
+
+TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
+  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
+
+  int hook1_count = 0;
+  auto hook1 = ([&hook1_count](Variable grad) {
+    hook1_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
+  });
+
+  int hook2_count = 0;
+  auto hook2 = ([&hook2_count](Variable grad) {
+    hook2_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
+  });
+
+  int hook3_count = 0;
+  auto hook3 = ([&hook3_count](Variable grad) {
+    hook3_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
+  });
+
+  a.register_hook(hook1);
+  a.retain_grad();
+  a.register_hook(hook2);
+  a.mul_(2);
+  a.register_hook(hook3);
+
+  auto out = (a + 1).sum();
+  out.backward();
+
+  ASSERT_EQ(hook1_count, 1);
+  ASSERT_EQ(hook2_count, 1);
+  ASSERT_EQ(hook3_count, 1);
+
+  ASSERT_TRUE(a.retains_grad());
+  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
+}
+
+TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
+  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
+
+  int hook1_count = 0;
+  auto hook1 = ([&hook1_count](Variable grad) {
+    hook1_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
+  });
+
+  int hook2_count = 0;
+  auto hook2 = ([&hook2_count](Variable grad) {
+    hook2_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
+  });
+
+  int hook3_count = 0;
+  auto hook3 = ([&hook3_count](Variable grad) {
+    hook3_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
+  });
+
+  a.register_hook(hook1);
+  a.retain_grad();
+  a.register_hook(hook2);
+  a.mul_(2);
+  a.mul_(2);
+  a.register_hook(hook3);
+
+  auto out = (a + 1).sum();
+  out.backward();
+
+  ASSERT_EQ(hook1_count, 1);
+  ASSERT_EQ(hook2_count, 1);
+  ASSERT_EQ(hook3_count, 1);
+
+  ASSERT_TRUE(a.retains_grad());
+  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
+}
+
 TEST(CustomAutogradTest, HookNone) {
   struct NoneGradientFunction : public Function<NoneGradientFunction> {
     static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {

test/test_autograd.py

@@ -836,6 +836,41 @@ class TestAutograd(TestCase):
         out.backward()
         self.assertEqual(input * 18, input.grad)
 
+    # NB: See test/cpp/api/autograd.cpp for more tests on the interaction between
+    #     retains_grad and hooks in cpp. There's no point testing in python because
+    #     Python hooks use a completely different mechanism.
+    def test_retain_grad_inplace(self):
+        a = torch.tensor([1.], requires_grad=True).clone()
+        a.retain_grad()
+        a.mul_(2)
+        a.sum().backward()
+        self.assertEqual(a.grad, torch.tensor([1.]))
+
+        a = torch.tensor([1.], requires_grad=True).clone()
+        a.retain_grad()
+        # Inplace multiple times is OK, the real test here would be in cpp though
+        # because the index here is always zero, having cpp hooks in addition,
+        # will force us to properly update the index
+        a.mul_(2)
+        a.mul_(2)
+        a.sum().backward()
+        self.assertEqual(a.grad, torch.tensor([1.]))
+
+    def test_retain_grad_inplace_over_view(self):
+        base = torch.tensor([1.], requires_grad=True).clone()
+        view = base[:]
+        view2 = base[:]
+        view.retain_grad()
+        view2.retain_grad()
+        view.mul_(2)
+        (view + view2).sum().backward()
+        # The old grad_fn, slice, wouldn't be part of the graph during backward
+        # so if the retains grad were not properly updated to the new grad_fn,
+        # the grad would still be None
+        self.assertEqual(view.grad, view2.grad)
+        self.assertEqual(view.grad, torch.tensor([1.]))
+
     def test_retain_grad_cycle(self):
         x = torch.ones(5, 5, requires_grad=True)

torch/csrc/autograd/variable.cpp

@@ -154,6 +154,40 @@ AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
   return get_autograd_meta(self);
 }
 
+void update_cpp_hooks_on_new_gradfn(
+    const at::TensorBase& self,
+    const std::shared_ptr<torch::autograd::Node>& new_fn) {
+  // This function is called whenever the grad_fn of the tensor is
+  // changed. We assume here that new_fn does not yet have hooks of
+  // its own
+  //
+  // This function does two things:
+  const auto& meta = impl::get_autograd_meta(self);
+  TORCH_INTERNAL_ASSERT(meta);
+  TORCH_INTERNAL_ASSERT(new_fn);
+  if (!self.retains_grad()) {
+    // (1) reset the list when grad_fn is updated, so new hooks don't
+    //     get erroneously registered to the old grad_fn.
+    //     Note that the old cpp_hooks_list_ is still kept alive by the
+    //     old grad_fn so hooks registered to the older version of the tensor
+    //     will continue to be active.
+    meta->cpp_hooks_list_ = nullptr;
+    return;
+  }
+  // (2) If there is a retains_grad hook registered, move that from the
+  //     old cpp_hooks_list_ to the new one
+  auto idx = meta->retains_grad_;
+  auto new_list = std::make_shared<hooks_list>();
+  new_list->push_back(std::move((*meta->cpp_hooks_list_)[idx]));
+  (*meta->cpp_hooks_list_)[idx] = nullptr;
+  meta->cpp_hooks_list_ = new_list;
+  // Since this is a new list, 0 is the index of the retains_grad hook
+  meta->retains_grad_ = 0;
+  std::unique_ptr<FunctionPreHook> hook_ptr(
+      new CppFunctionPreHook(meta->cpp_hooks_list_, self.output_nr()));
+  new_fn->add_pre_hook(std::move(hook_ptr));
+}
+
 void rebase_history(const Variable& self, Edge gradient_edge) {
   TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr);
   auto diff_view_meta = get_view_autograd_meta(self);
@@ -181,6 +215,8 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
   }
 
   set_gradient_edge(self, std::move(gradient_edge));
+  // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
+  torch::autograd::impl::update_cpp_hooks_on_new_gradfn(self, self.grad_fn());
 }
 
 void create_cpp_hook(const at::TensorBase& self) {
@@ -495,7 +531,7 @@ void VariableHooks::retain_grad(const at::TensorBase& self) const {
   if (self.is_leaf()) { // no-op for leaves
     return;
   }
-  if (impl::get_autograd_meta(self)->retains_grad_) {
+  if (impl::get_autograd_meta(self)->retains_grad_ != -1) {
     return;
   }
   c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr());
@@ -517,13 +553,13 @@ void VariableHooks::retain_grad(const at::TensorBase& self) const {
     }
   };
 
-  at::OptionalTensorRef(self)->register_hook(retain_grad_hook);
-  impl::get_autograd_meta(self)->retains_grad_ = true;
+  auto idx = at::OptionalTensorRef(self)->register_hook(retain_grad_hook);
+  impl::get_autograd_meta(self)->retains_grad_ = idx;
 }
 
 bool VariableHooks::retains_grad(const at::TensorBase& self) const {
   if (impl::get_autograd_meta(self)) {
-    return impl::get_autograd_meta(self)->retains_grad_;
+    return impl::get_autograd_meta(self)->retains_grad_ != -1;
   } else {
     return false;
   }
@@ -661,6 +697,9 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
        diff_view_meta->grad_fn_ = std::move(fn);
      }
      diff_view_meta->set_attr_version(current_version);
+
+     torch::autograd::impl::update_cpp_hooks_on_new_gradfn(
+         self, diff_view_meta->grad_fn_);
    }
    return diff_view_meta->grad_fn_;
 }
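
For context on the grad_fn() change just above: when a view is mutated in place, its grad_fn is rebuilt lazily on the next access, so the retains_grad hook has to be re-attached to the rebuilt grad_fn as well. A minimal Python sketch of the behavior this preserves (it mirrors `test_retain_grad_inplace_over_view` in the Python diff above):

```python
import torch

base = torch.tensor([1.], requires_grad=True).clone()
view, view2 = base[:], base[:]
view.retain_grad()
view2.retain_grad()
view.mul_(2)                    # view's grad_fn is only rebuilt when it is next accessed
(view + view2).sum().backward()
print(view.grad, view2.grad)    # both tensor([1.]); without the fix view.grad would stay None,
                                # since the old grad_fn (a slice) is no longer part of the graph
```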

torch/csrc/autograd/variable.h

@@ -222,8 +222,10 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
   // Only meaningful on leaf variables (must be false otherwise)
   bool requires_grad_;
 
-  // Only meaningful on non-leaf variables (must be false otherwise)
-  bool retains_grad_;
+  // Only meaningful on non-leaf variables (must be -1 otherwise)
+  // The value of retains_grad_ indicates the index of it in cpp_hooks_list_
+  // A value of -1 indicates that the tensor does not retain grad
+  int64_t retains_grad_;
 
   bool is_view_;
@@ -281,7 +283,7 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
       Edge gradient_edge = Edge()) {
     grad_fn_ = std::move(gradient_edge.function);
     requires_grad_ = false;
-    retains_grad_ = false;
+    retains_grad_ = -1;
     is_view_ = false;
     output_nr_ = gradient_edge.input_nr;