Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Simplify retains grad hook implementation (#92604)
How the old retains_grad hooks were implemented:
- retains_grad hooks are stored on the autograd_meta, as entries in a vector
- upon registration, a wrapper hook CppFunctionTensorPreHook is created to wrap that vector, and that wrapper hook is registered to the grad_fn, i.e., by appending it to a vector of retains_grad hooks on the grad_fn
- upon in-place, for the old grad_fn we set the retains_grad hook to nullptr, so that even though the old grad_fn still references the vector, the vector contains a single nullptr. For the new grad_fn, we create a new wrapper hook around the vector (storing the single retains_grad hook) on autograd_meta.

The new retains_grad hook implementation:
- we store the std::function by value, and we store it on the grad_fn rather than the autograd_meta
- a single grad_fn can have multiple outputs, so it can potentially hold multiple retains_grad hooks. We use an unordered_map (previously a vector).
- on in-place we remove the hook from the old grad_fn and put it on the new grad_fn (a small implication of this change is that we now need access to both the old grad_fn and the new grad_fn; this isn't a problem)

Other details:
- CppFunctionTensorPreHook took a shared_ptr to a vector of std::function. In our new implementation, we add a new wrapper hook CppFunctionSingleTensorPreHook, which takes a single std::function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92604
Approved by: https://github.com/albanD
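To illustrate the data-structure change described above, here is a minimal, self-contained C++ sketch. The types Hook, Node, and move_retains_grad_hook are simplified placeholders for illustration only, not the actual Node/FunctionPreHook classes touched in the diff below; the point is just that each grad_fn owns its retains_grad hooks in an unordered_map keyed by output index, so an in-place update simply moves one map entry from the old grad_fn to the new one.

// Simplified sketch of the new storage scheme (assumed/illustrative types).
#include <cassert>
#include <functional>
#include <memory>
#include <unordered_map>
#include <utility>

struct Hook {
  // Placeholder for FunctionPreHook; the real hook wraps a
  // std::function<at::TensorBase(const at::TensorBase&)>.
  std::function<void()> fn;
};

struct Node {
  // retains_grad hooks keyed by output index (previously a vector on autograd_meta).
  std::unordered_map<int, std::unique_ptr<Hook>> retains_grad_hooks;

  void add_retains_grad_hook(std::unique_ptr<Hook> hook, int output_idx) {
    retains_grad_hooks[output_idx] = std::move(hook);
  }

  std::unique_ptr<Hook> pop_retains_grad_hook(int output_idx) {
    auto it = retains_grad_hooks.find(output_idx);
    assert(it != retains_grad_hooks.end());
    auto hook = std::move(it->second);
    retains_grad_hooks.erase(it);
    return hook;
  }
};

// On an in-place op the tensor gets a new grad_fn; the retains_grad hook for
// that output moves from the old node to the new one.
void move_retains_grad_hook(Node& old_fn, Node& new_fn, int output_idx) {
  new_fn.add_retains_grad_hook(old_fn.pop_retains_grad_hook(output_idx),
                               output_idx);
}

int main() {
  Node old_fn, new_fn;
  old_fn.add_retains_grad_hook(std::make_unique<Hook>(), /*output_idx=*/0);
  move_retains_grad_hook(old_fn, new_fn, /*output_idx=*/0);
  assert(new_fn.retains_grad_hooks.count(0) == 1);
  assert(old_fn.retains_grad_hooks.empty());
  return 0;
}

In the actual change, the stored std::function is additionally wrapped in the new CppFunctionSingleTensorPreHook so it can be invoked through the existing pre-hook call path, as shown in the hunks below.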
Committed by: PyTorch MergeBot
Parent: 71b1051230
Commit: a112814a7f
@@ -48,5 +48,22 @@ variable_list CppFunctionTensorPreHook::operator()(
   return results;
 }

+// NOLINTNEXTLINE(modernize-pass-by-value)
+CppFunctionSingleTensorPreHook::CppFunctionSingleTensorPreHook(
+    std::function<at::TensorBase(const at::TensorBase&)> hook,
+    int value_idx)
+    : hook_(hook), value_idx_(value_idx) {}
+
+variable_list CppFunctionSingleTensorPreHook::operator()(
+    const variable_list& values) {
+  auto value = values[value_idx_];
+  auto res = hook_(value);
+  TORCH_INTERNAL_ASSERT(
+      !res.defined(),
+      "CppFunctionSingleTensorPreHook currently only supports hooks that don't return");
+  variable_list results(values);
+  return results;
+}
+
 } // namespace autograd
 } // namespace torch
@@ -19,5 +19,15 @@ struct CppFunctionTensorPreHook : public FunctionPreHook {
   int value_idx_;
 };

+struct CppFunctionSingleTensorPreHook : public FunctionPreHook {
+  CppFunctionSingleTensorPreHook(
+      std::function<at::TensorBase(const at::TensorBase&)> hook,
+      int value_idx);
+  variable_list operator()(const variable_list& values) override;
+
+  std::function<at::TensorBase(const at::TensorBase&)> hook_;
+  int value_idx_;
+};
+
 } // namespace autograd
 } // namespace torch
@@ -757,8 +757,8 @@ static variable_list call_tensor_pre_hooks(Node& fn, variable_list inputs) {
   for (const auto& hook : fn.tensor_pre_hooks()) {
     inputs = (*hook)(inputs);
   }
-  for (const auto& hook : fn.retains_grad_hooks()) {
-    inputs = (*hook)(inputs);
+  for (const auto& pair : fn.retains_grad_hooks()) {
+    inputs = (*pair.second)(inputs);
   }
   return inputs;
 }
@@ -490,8 +490,16 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     tensor_pre_hooks_.push_back(std::move(pre_hook));
   }

-  void add_retains_grad_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
-    retains_grad_hooks_.push_back(std::move(pre_hook));
+  void add_retains_grad_hook(
+      std::unique_ptr<FunctionPreHook>&& pre_hook,
+      int output_idx) {
+    retains_grad_hooks_[output_idx] = std::move(pre_hook);
   }

+  std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(int output_idx) {
+    auto ret = std::move(retains_grad_hooks_[output_idx]);
+    retains_grad_hooks_.erase(output_idx);
+    return ret;
+  }
+
   const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks()
@@ -508,7 +516,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     return tensor_pre_hooks_;
   }

-  std::vector<std::unique_ptr<FunctionPreHook>>& retains_grad_hooks() noexcept {
+  std::unordered_map<int, std::unique_ptr<FunctionPreHook>>&
+  retains_grad_hooks() noexcept {
     return retains_grad_hooks_;
   }

@@ -636,7 +645,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
-  std::vector<std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;
+  std::unordered_map<int, std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
@@ -86,8 +86,9 @@ int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
   // In theory this shouldn't be necessary, because retains_grad_hooks should
   // not contain any PyFunctionTensorPreHooks. The alternative is to have a
   // check that actually guarantees this.
-  for (const auto& hook : fn.retains_grad_hooks()) {
-    if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
+  for (const auto& pair : fn.retains_grad_hooks()) {
+    if (auto pyhook =
+            dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
       Py_VISIT(pyhook->dict);
     }
   }
@@ -219,8 +219,9 @@ static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
     }
   }
   // See NOTE [retains_grad_hook PyObject traversal]
-  for (const auto& hook : cdata->retains_grad_hooks()) {
-    if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
+  for (const auto& pair : cdata->retains_grad_hooks()) {
+    if (auto pyhook =
+            dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
       Py_VISIT(pyhook->dict);
     }
   }
@@ -158,42 +158,41 @@ AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {

 void update_tensor_hooks_on_new_gradfn(
     const at::TensorBase& self,
+    const std::shared_ptr<torch::autograd::Node>& old_fn,
     const std::shared_ptr<torch::autograd::Node>& new_fn) {
   // This function is called whenever the grad_fn of the tensor is
   // changed. We assume here that new_fn does not yet have hooks of
-  // its own
-  const auto& meta = impl::get_autograd_meta(self);
-  TORCH_INTERNAL_ASSERT(meta);
-  TORCH_INTERNAL_ASSERT(new_fn);
+  // its own.
+  //
+  // This function does two things:
   // (1) reset the list when grad_fn is updated, so new hooks don't
   // get erroneously registered to the old grad_fn.
   // Note that the old cpp_hooks_list_ is still kept alive by the
   // old grad_fn so hooks registered to the older version of the tensor
   // will continue to be active.
+  // (2) If there is a retains_grad hook registered, move that from the
+  // old cpp_hooks_list_ to the new one
+  const auto& meta = impl::get_autograd_meta(self);
+  TORCH_INTERNAL_ASSERT(meta);
+  TORCH_INTERNAL_ASSERT(new_fn);
   meta->cpp_hooks_list_ = nullptr;
   const c10::impl::PyInterpreter* interp =
       self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter();
   if (interp) {
     (*interp)->reset_backward_hooks(self.unsafeGetTensorImpl());
   }
-  // (2) If there is a retains_grad hook registered, move that from the
-  // old cpp_hooks_list_ to the new one
   if (self.retains_grad()) {
-    auto new_list = std::make_shared<hooks_list>();
-    new_list->push_back(std::move((*meta->retains_grad_hooks_list_)[0]));
-    (*meta->retains_grad_hooks_list_)[0] = nullptr;
-    meta->retains_grad_hooks_list_ = new_list;
-    std::unique_ptr<FunctionPreHook> hook_ptr =
-        std::make_unique<CppFunctionTensorPreHook>(
-            meta->retains_grad_hooks_list_, self.output_nr());
-    new_fn->add_retains_grad_hook(std::move(hook_ptr));
+    TORCH_INTERNAL_ASSERT(old_fn);
+    auto out = old_fn->pop_retains_grad_hook(self.output_nr());
+    TORCH_INTERNAL_ASSERT(out != nullptr);
+    new_fn->add_retains_grad_hook(std::move(out), self.output_nr());
   }
 }

 void rebase_history(const Variable& self, Edge gradient_edge) {
   TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr);
+  const auto& meta = impl::get_autograd_meta(self);
+  auto old_fn = meta != nullptr ? meta->grad_fn_ : nullptr;
   auto diff_view_meta = get_view_autograd_meta(self);
   if (diff_view_meta && diff_view_meta->has_bw_view()) {
     // See NOTE [ View + Inplace detection ]
@@ -221,21 +220,11 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
   set_gradient_edge(self, std::move(gradient_edge));
   // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
   torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
-      self, self.grad_fn());
+      self, old_fn, self.grad_fn());
 }

 void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
   const auto& fn = self.grad_fn();
-  if (is_retains_grad_hook) {
-    std::shared_ptr<hooks_list>& list =
-        materialize_autograd_meta(self)->retains_grad_hooks_list_;
-    // NOLINTNEXTLINE(modernize-make-shared)
-    list.reset(new hooks_list());
-    std::unique_ptr<FunctionPreHook> hook_ptr{
-        new CppFunctionTensorPreHook(list, self.output_nr())};
-    TORCH_INTERNAL_ASSERT(fn, "Expect grad_fn to be defined for retains_grad");
-    fn->add_retains_grad_hook(std::move(hook_ptr));
-  } else {
   std::shared_ptr<hooks_list>& list =
       materialize_autograd_meta(self)->cpp_hooks_list_;
   // NOLINTNEXTLINE(modernize-make-shared)
@@ -251,7 +240,6 @@ void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
     fn->add_tensor_pre_hook(std::move(hook_ptr));
   }
-  }
 }

 void set_grad_accumulator(
     const Variable& self,
@@ -529,24 +517,6 @@ int64_t VariableHooks::_version(const at::TensorBase& self) const {
   return self.unsafeGetTensorImpl()->version_counter().current_version();
 }

-unsigned register_retains_grad_hook(
-    const at::TensorBase& self,
-    std::function<at::TensorBase(const at::TensorBase&)> hook) {
-  TORCH_CHECK(
-      self.requires_grad(),
-      "cannot retain grad on a variable that "
-      "doesn't require gradient");
-  // NB: materialize_autograd_meta unnecessary due to requires grad check
-  auto& list =
-      torch::autograd::impl::get_autograd_meta(self)->retains_grad_hooks_list_;
-  if (!list) {
-    torch::autograd::impl::create_cpp_hook(self, /*is_retains_grad_hook=*/true);
-  }
-  unsigned idx = list->size();
-  list->push_back(hook);
-  return idx;
-}
-
 void VariableHooks::retain_grad(const at::TensorBase& self) const {
   TORCH_CHECK(
       self.requires_grad(),
@@ -583,7 +553,10 @@ void VariableHooks::retain_grad(const at::TensorBase& self) const {
     return at::TensorBase{};
   };

-  register_retains_grad_hook(self, retain_grad_hook);
+  const auto& fn = self.grad_fn();
+  std::unique_ptr<FunctionPreHook> hook_ptr{new CppFunctionSingleTensorPreHook(
+      std::move(retain_grad_hook), self.output_nr())};
+  fn->add_retains_grad_hook(std::move(hook_ptr), self.output_nr());
   impl::get_autograd_meta(self)->retains_grad_ = true;
 }

@@ -674,6 +647,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
       return diff_view_meta->grad_fn_;
     }
     auto current_version = self._version();
+    auto old_fn = diff_view_meta->grad_fn_;
     if (diff_view_meta->get_attr_version() != current_version) {
       // This is an indirect rebase_history due to another view or the base
       // being modified inplace
@@ -735,7 +709,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
       diff_view_meta->set_attr_version(current_version);

       torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
-          self, diff_view_meta->grad_fn_);
+          self, old_fn, diff_view_meta->grad_fn_);
     }
     return diff_view_meta->grad_fn_;
   }
@@ -229,7 +229,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
   // each other, so using both is not defined behavior.
   std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
   std::shared_ptr<hooks_list> cpp_hooks_list_;
-  std::shared_ptr<hooks_list> retains_grad_hooks_list_;

   // Only meaningful on leaf variables (must be false otherwise)
   bool requires_grad_;