Fix retains grad behavior after in-place (#79996)
See this doc: https://docs.google.com/document/d/1KiRdnoj6B4cI3yl017hTbCqcOGO1gWIpUf20sldipHM/edit#

Two issues are fixed: (1) one regarding hooks in general and (2) one regarding retains_grad hooks. Python hooks, which rely on a different mechanism, are not addressed here:

- Hooks in cpp in general
  - (fixed) new hooks registered to a newer version of the tensor no longer get applied to the grad_fn associated with an older version of the tensor (the version at which the first hook was registered)
  - (unchanged) hooks registered to the older version of the tensor remain active on the old grad_fn
- Retains grad hooks
  - (fixed) the retains_grad hook now gets moved to the latest grad_fn. NB: to the user, retains_grad is not considered a hook and is not expected to behave like one; hooks are treated as properties of the grad_fn, whereas retains_grad is a property of the tensor.
- Python hooks (not in this PR)
  - (will fix) same issue as the cpp hooks, where new hooks are being applied to the grad_fn associated with the older version of the tensor

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79996
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: e9b3bc2ead
Commit: 516f3198d6
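For context, here is a minimal Python sketch of the user-visible behavior this change targets. It mirrors the test_retain_grad_inplace test added in the diff below and is illustrative only, not part of the commit:

import torch

# Non-leaf tensor (clone of a leaf), grad retained, then mutated in place.
a = torch.tensor([1.], requires_grad=True).clone()
a.retain_grad()
a.mul_(2)

a.sum().backward()

# With the retains_grad hook moved to the latest grad_fn, the gradient is
# populated as expected: d(sum)/da is all ones.
print(a.grad)  # tensor([1.])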
@@ -831,6 +831,111 @@ TEST(CustomAutogradTest, Hooks) {
   ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
 }
 
+TEST(CustomAutogradTest, HooksInplace) {
+  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
+
+  int hook1_count = 0;
+  auto hook1 = ([&hook1_count](Variable grad) {
+    hook1_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
+  });
+
+  int hook2_count = 0;
+  auto hook2 = ([&hook2_count](Variable grad) {
+    hook2_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
+  });
+
+  a.register_hook(hook1);
+  a.mul_(2);
+  a.register_hook(hook2);
+
+  auto out = (a + 1).sum();
+  out.backward();
+
+  ASSERT_EQ(hook1_count, 1);
+  ASSERT_EQ(hook2_count, 1);
+}
+
+TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
+  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
+
+  int hook1_count = 0;
+  auto hook1 = ([&hook1_count](Variable grad) {
+    hook1_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
+  });
+
+  int hook2_count = 0;
+  auto hook2 = ([&hook2_count](Variable grad) {
+    hook2_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
+  });
+
+  int hook3_count = 0;
+  auto hook3 = ([&hook3_count](Variable grad) {
+    hook3_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
+  });
+
+  a.register_hook(hook1);
+  a.retain_grad();
+  a.register_hook(hook2);
+
+  a.mul_(2);
+  a.register_hook(hook3);
+
+  auto out = (a + 1).sum();
+  out.backward();
+
+  ASSERT_EQ(hook1_count, 1);
+  ASSERT_EQ(hook2_count, 1);
+  ASSERT_EQ(hook3_count, 1);
+
+  ASSERT_TRUE(a.retains_grad());
+  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
+}
+
+TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
+  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
+
+  int hook1_count = 0;
+  auto hook1 = ([&hook1_count](Variable grad) {
+    hook1_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
+  });
+
+  int hook2_count = 0;
+  auto hook2 = ([&hook2_count](Variable grad) {
+    hook2_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
+  });
+
+  int hook3_count = 0;
+  auto hook3 = ([&hook3_count](Variable grad) {
+    hook3_count++;
+    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
+  });
+
+  a.register_hook(hook1);
+  a.retain_grad();
+  a.register_hook(hook2);
+
+  a.mul_(2);
+  a.mul_(2);
+  a.register_hook(hook3);
+
+  auto out = (a + 1).sum();
+  out.backward();
+
+  ASSERT_EQ(hook1_count, 1);
+  ASSERT_EQ(hook2_count, 1);
+  ASSERT_EQ(hook3_count, 1);
+
+  ASSERT_TRUE(a.retains_grad());
+  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
+}
+
 TEST(CustomAutogradTest, HookNone) {
   struct NoneGradientFunction : public Function<NoneGradientFunction> {
     static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
@@ -836,6 +836,41 @@ class TestAutograd(TestCase):
         out.backward()
         self.assertEqual(input * 18, input.grad)
 
+    # NB: See test/cpp/api/autograd.cpp for more tests on the interaction between
+    # retains_grad and hooks in cpp. There's no point testing in python because
+    # Python hooks use a completely different mechanism.
+    def test_retain_grad_inplace(self):
+        a = torch.tensor([1.], requires_grad=True).clone()
+        a.retain_grad()
+        a.mul_(2)
+        a.sum().backward()
+        self.assertEqual(a.grad, torch.tensor([1.]))
+
+        a = torch.tensor([1.], requires_grad=True).clone()
+        a.retain_grad()
+        # Inplace multiple times is OK, the real test here would be in cpp though
+        # because the index here is always zero, having cpp hooks in addition,
+        # will force us to properly update the index
+        a.mul_(2)
+        a.mul_(2)
+        a.sum().backward()
+        self.assertEqual(a.grad, torch.tensor([1.]))
+
+    def test_retain_grad_inplace_over_view(self):
+        base = torch.tensor([1.], requires_grad=True).clone()
+        view = base[:]
+        view2 = base[:]
+        view.retain_grad()
+        view2.retain_grad()
+        view.mul_(2)
+        (view + view2).sum().backward()
+
+        # The old grad_fn, slice, wouldn't be part of the graph during backward
+        # so if the retains grad were not properly updated to the new grad_fn,
+        # the grad would still be None
+        self.assertEqual(view.grad, view2.grad)
+        self.assertEqual(view.grad, torch.tensor([1.]))
+
     def test_retain_grad_cycle(self):
         x = torch.ones(5, 5, requires_grad=True)
 
@@ -154,6 +154,40 @@ AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
   return get_autograd_meta(self);
 }
 
+void update_cpp_hooks_on_new_gradfn(
+    const at::TensorBase& self,
+    const std::shared_ptr<torch::autograd::Node>& new_fn) {
+  // This function is called whenever the grad_fn of the tensor is
+  // changed. We assume here that new_fn does not yet have hooks of
+  // its own
+  //
+  // This function does two things:
+  const auto& meta = impl::get_autograd_meta(self);
+  TORCH_INTERNAL_ASSERT(meta);
+  TORCH_INTERNAL_ASSERT(new_fn);
+  if (!self.retains_grad()) {
+    // (1) reset the list when grad_fn is updated, so new hooks don't
+    // get erroneously registered to the old grad_fn.
+    // Note that the old cpp_hooks_list_ is still kept alive by the
+    // old grad_fn so hooks registered to the older version of the tensor
+    // will continue to be active.
+    meta->cpp_hooks_list_ = nullptr;
+    return;
+  }
+  // (2) If there is a retains_grad hook registered, move that from the
+  // old cpp_hooks_list_ to the new one
+  auto idx = meta->retains_grad_;
+  auto new_list = std::make_shared<hooks_list>();
+  new_list->push_back(std::move((*meta->cpp_hooks_list_)[idx]));
+  (*meta->cpp_hooks_list_)[idx] = nullptr;
+  meta->cpp_hooks_list_ = new_list;
+  // Since this is a new list, 0 is the index of the retains_grad hook
+  meta->retains_grad_ = 0;
+  std::unique_ptr<FunctionPreHook> hook_ptr(
+      new CppFunctionPreHook(meta->cpp_hooks_list_, self.output_nr()));
+  new_fn->add_pre_hook(std::move(hook_ptr));
+}
+
 void rebase_history(const Variable& self, Edge gradient_edge) {
   TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr);
   auto diff_view_meta = get_view_autograd_meta(self);
@@ -181,6 +215,8 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
   }
 
   set_gradient_edge(self, std::move(gradient_edge));
+  // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
+  torch::autograd::impl::update_cpp_hooks_on_new_gradfn(self, self.grad_fn());
 }
 
 void create_cpp_hook(const at::TensorBase& self) {
@@ -495,7 +531,7 @@ void VariableHooks::retain_grad(const at::TensorBase& self) const {
   if (self.is_leaf()) { // no-op for leaves
     return;
   }
-  if (impl::get_autograd_meta(self)->retains_grad_) {
+  if (impl::get_autograd_meta(self)->retains_grad_ != -1) {
     return;
   }
   c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr());
@@ -517,13 +553,13 @@ void VariableHooks::retain_grad(const at::TensorBase& self) const {
    }
  };
 
-  at::OptionalTensorRef(self)->register_hook(retain_grad_hook);
-  impl::get_autograd_meta(self)->retains_grad_ = true;
+  auto idx = at::OptionalTensorRef(self)->register_hook(retain_grad_hook);
+  impl::get_autograd_meta(self)->retains_grad_ = idx;
 }
 
 bool VariableHooks::retains_grad(const at::TensorBase& self) const {
   if (impl::get_autograd_meta(self)) {
-    return impl::get_autograd_meta(self)->retains_grad_;
+    return impl::get_autograd_meta(self)->retains_grad_ != -1;
   } else {
     return false;
   }
@@ -661,6 +697,9 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
       diff_view_meta->grad_fn_ = std::move(fn);
     }
     diff_view_meta->set_attr_version(current_version);
+
+    torch::autograd::impl::update_cpp_hooks_on_new_gradfn(
+        self, diff_view_meta->grad_fn_);
   }
   return diff_view_meta->grad_fn_;
 }
@@ -222,8 +222,10 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
   // Only meaningful on leaf variables (must be false otherwise)
   bool requires_grad_;
 
-  // Only meaningful on non-leaf variables (must be false otherwise)
-  bool retains_grad_;
+  // Only meaningful on non-leaf variables (must be -1 otherwise)
+  // The value of retains_grad_ indicates the index of it in cpp_hooks_list_
+  // A value of -1 indicates that the tensor does not retain grad
+  int64_t retains_grad_;
 
   bool is_view_;
 
@@ -281,7 +283,7 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
       Edge gradient_edge = Edge()) {
     grad_fn_ = std::move(gradient_edge.function);
     requires_grad_ = false;
-    retains_grad_ = false;
+    retains_grad_ = -1;
     is_view_ = false;
     output_nr_ = gradient_edge.input_nr;
 