mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Simplify retains grad hook implementation (#92604)
How the old retains_grad hooks was implemented: - retains_grad hooks are stored on the autograd_meta, as entries in a vector - upon registration, a wrapper hook CppFunctionTensorPreHook is created to wrap that vector, and then that wrapper hook is registered to the grad_fn, i.e., by appending it to a vector of retains_grad hooks on the grad_fn - upon in-place, for the old grad_fn we set the retains_grad hook to nullptr, so that even though the old grad_fn still references the vector, the vector contains a single nullptr. For the new grad_fn, we create a new wrapper hook around the vector (storing the single retains_grad hook) on autograd_meta. The new retains_grad hook implementation: - we store std::function by value, and we store it on the grad_fn rather than the autograd_meta - a single grad_fn can have multiple outputs, so it can potentially hold multiple retains_grad hooks. We use an unordered_map (previously a vector). - on in-place we remove the hook from the old grad_fn and put it in the new grad_fn (small implication of this change is that we we now need to have access to both the old grad_fn and new grad_fn, this isn't a problem) Other details: - CppFunctionTensorPreHook took a shared_ptr to vector of std::function. In our new implementation, we add a new wrapper hook CppFunctionSingleTensorPreHook, which takes a single std::function. Pull Request resolved: https://github.com/pytorch/pytorch/pull/92604 Approved by: https://github.com/albanD
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							71b1051230
						
					
				
				
					commit
					a112814a7f
				
			| @ -48,5 +48,22 @@ variable_list CppFunctionTensorPreHook::operator()( | ||||
|   return results; | ||||
| } | ||||
|  | ||||
| // NOLINTNEXTLINE(modernize-pass-by-value) | ||||
| CppFunctionSingleTensorPreHook::CppFunctionSingleTensorPreHook( | ||||
|     std::function<at::TensorBase(const at::TensorBase&)> hook, | ||||
|     int value_idx) | ||||
|     : hook_(hook), value_idx_(value_idx) {} | ||||
|  | ||||
| variable_list CppFunctionSingleTensorPreHook::operator()( | ||||
|     const variable_list& values) { | ||||
|   auto value = values[value_idx_]; | ||||
|   auto res = hook_(value); | ||||
|   TORCH_INTERNAL_ASSERT( | ||||
|       !res.defined(), | ||||
|       "CppFunctionSingleTensorPreHook currently only supports hooks that don't return"); | ||||
|   variable_list results(values); | ||||
|   return results; | ||||
| } | ||||
|  | ||||
| } // namespace autograd | ||||
| } // namespace torch | ||||
|  | ||||
| @ -19,5 +19,15 @@ struct CppFunctionTensorPreHook : public FunctionPreHook { | ||||
|   int value_idx_; | ||||
| }; | ||||
|  | ||||
| struct CppFunctionSingleTensorPreHook : public FunctionPreHook { | ||||
|   CppFunctionSingleTensorPreHook( | ||||
|       std::function<at::TensorBase(const at::TensorBase&)> hook, | ||||
|       int value_idx); | ||||
|   variable_list operator()(const variable_list& values) override; | ||||
|  | ||||
|   std::function<at::TensorBase(const at::TensorBase&)> hook_; | ||||
|   int value_idx_; | ||||
| }; | ||||
|  | ||||
| } // namespace autograd | ||||
| } // namespace torch | ||||
|  | ||||
| @ -757,8 +757,8 @@ static variable_list call_tensor_pre_hooks(Node& fn, variable_list inputs) { | ||||
|   for (const auto& hook : fn.tensor_pre_hooks()) { | ||||
|     inputs = (*hook)(inputs); | ||||
|   } | ||||
|   for (const auto& hook : fn.retains_grad_hooks()) { | ||||
|     inputs = (*hook)(inputs); | ||||
|   for (const auto& pair : fn.retains_grad_hooks()) { | ||||
|     inputs = (*pair.second)(inputs); | ||||
|   } | ||||
|   return inputs; | ||||
| } | ||||
|  | ||||
| @ -490,8 +490,16 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | ||||
|     tensor_pre_hooks_.push_back(std::move(pre_hook)); | ||||
|   } | ||||
|  | ||||
|   void add_retains_grad_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) { | ||||
|     retains_grad_hooks_.push_back(std::move(pre_hook)); | ||||
|   void add_retains_grad_hook( | ||||
|       std::unique_ptr<FunctionPreHook>&& pre_hook, | ||||
|       int output_idx) { | ||||
|     retains_grad_hooks_[output_idx] = std::move(pre_hook); | ||||
|   } | ||||
|  | ||||
|   std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(int output_idx) { | ||||
|     auto ret = std::move(retains_grad_hooks_[output_idx]); | ||||
|     retains_grad_hooks_.erase(output_idx); | ||||
|     return ret; | ||||
|   } | ||||
|  | ||||
|   const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() | ||||
| @ -508,7 +516,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | ||||
|     return tensor_pre_hooks_; | ||||
|   } | ||||
|  | ||||
|   std::vector<std::unique_ptr<FunctionPreHook>>& retains_grad_hooks() noexcept { | ||||
|   std::unordered_map<int, std::unique_ptr<FunctionPreHook>>& | ||||
|   retains_grad_hooks() noexcept { | ||||
|     return retains_grad_hooks_; | ||||
|   } | ||||
|  | ||||
| @ -636,7 +645,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | ||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) | ||||
|   std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_; | ||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) | ||||
|   std::vector<std::unique_ptr<FunctionPreHook>> retains_grad_hooks_; | ||||
|   std::unordered_map<int, std::unique_ptr<FunctionPreHook>> retains_grad_hooks_; | ||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) | ||||
|   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_; | ||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) | ||||
|  | ||||
| @ -86,8 +86,9 @@ int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) { | ||||
|   // In theory this shouldn't be necessary, because retains_grad_hooks should | ||||
|   // not contain any PyFunctionTensorPreHooks. The alternative is to have a | ||||
|   // check that actually guarantees this. | ||||
|   for (const auto& hook : fn.retains_grad_hooks()) { | ||||
|     if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) { | ||||
|   for (const auto& pair : fn.retains_grad_hooks()) { | ||||
|     if (auto pyhook = | ||||
|             dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) { | ||||
|       Py_VISIT(pyhook->dict); | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -219,8 +219,9 @@ static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) { | ||||
|       } | ||||
|     } | ||||
|     // See NOTE [retains_grad_hook PyObject traversal] | ||||
|     for (const auto& hook : cdata->retains_grad_hooks()) { | ||||
|       if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) { | ||||
|     for (const auto& pair : cdata->retains_grad_hooks()) { | ||||
|       if (auto pyhook = | ||||
|               dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) { | ||||
|         Py_VISIT(pyhook->dict); | ||||
|       } | ||||
|     } | ||||
|  | ||||
| @ -158,42 +158,41 @@ AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) { | ||||
|  | ||||
| void update_tensor_hooks_on_new_gradfn( | ||||
|     const at::TensorBase& self, | ||||
|     const std::shared_ptr<torch::autograd::Node>& old_fn, | ||||
|     const std::shared_ptr<torch::autograd::Node>& new_fn) { | ||||
|   // This function is called whenever the grad_fn of the tensor is | ||||
|   // changed. We assume here that new_fn does not yet have hooks of | ||||
|   // its own | ||||
|   // its own. | ||||
|   // | ||||
|   // This function does two things: | ||||
|   const auto& meta = impl::get_autograd_meta(self); | ||||
|   TORCH_INTERNAL_ASSERT(meta); | ||||
|   TORCH_INTERNAL_ASSERT(new_fn); | ||||
|   // (1) reset the list when grad_fn is updated, so new hooks don't | ||||
|   //     get erroneously registered to the old grad_fn. | ||||
|   //     Note that the old cpp_hooks_list_ is still kept alive by the | ||||
|   //     old grad_fn so hooks registered to the older version of the tensor | ||||
|   //     will continue to be active. | ||||
|   // (2) If there is a retains_grad hook registered, move that from the | ||||
|   //     old cpp_hooks_list_ to the new one | ||||
|   const auto& meta = impl::get_autograd_meta(self); | ||||
|   TORCH_INTERNAL_ASSERT(meta); | ||||
|   TORCH_INTERNAL_ASSERT(new_fn); | ||||
|   meta->cpp_hooks_list_ = nullptr; | ||||
|   const c10::impl::PyInterpreter* interp = | ||||
|       self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter(); | ||||
|   if (interp) { | ||||
|     (*interp)->reset_backward_hooks(self.unsafeGetTensorImpl()); | ||||
|   } | ||||
|   // (2) If there is a retains_grad hook registered, move that from the | ||||
|   //     old cpp_hooks_list_ to the new one | ||||
|   if (self.retains_grad()) { | ||||
|     auto new_list = std::make_shared<hooks_list>(); | ||||
|     new_list->push_back(std::move((*meta->retains_grad_hooks_list_)[0])); | ||||
|     (*meta->retains_grad_hooks_list_)[0] = nullptr; | ||||
|     meta->retains_grad_hooks_list_ = new_list; | ||||
|     std::unique_ptr<FunctionPreHook> hook_ptr = | ||||
|         std::make_unique<CppFunctionTensorPreHook>( | ||||
|             meta->retains_grad_hooks_list_, self.output_nr()); | ||||
|     new_fn->add_retains_grad_hook(std::move(hook_ptr)); | ||||
|     TORCH_INTERNAL_ASSERT(old_fn); | ||||
|     auto out = old_fn->pop_retains_grad_hook(self.output_nr()); | ||||
|     TORCH_INTERNAL_ASSERT(out != nullptr); | ||||
|     new_fn->add_retains_grad_hook(std::move(out), self.output_nr()); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void rebase_history(const Variable& self, Edge gradient_edge) { | ||||
|   TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr); | ||||
|   const auto& meta = impl::get_autograd_meta(self); | ||||
|   auto old_fn = meta != nullptr ? meta->grad_fn_ : nullptr; | ||||
|   auto diff_view_meta = get_view_autograd_meta(self); | ||||
|   if (diff_view_meta && diff_view_meta->has_bw_view()) { | ||||
|     // See NOTE [ View + Inplace detection ] | ||||
| @ -221,21 +220,11 @@ void rebase_history(const Variable& self, Edge gradient_edge) { | ||||
|   set_gradient_edge(self, std::move(gradient_edge)); | ||||
|   // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly | ||||
|   torch::autograd::impl::update_tensor_hooks_on_new_gradfn( | ||||
|       self, self.grad_fn()); | ||||
|       self, old_fn, self.grad_fn()); | ||||
| } | ||||
|  | ||||
| void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) { | ||||
|   const auto& fn = self.grad_fn(); | ||||
|   if (is_retains_grad_hook) { | ||||
|     std::shared_ptr<hooks_list>& list = | ||||
|         materialize_autograd_meta(self)->retains_grad_hooks_list_; | ||||
|     // NOLINTNEXTLINE(modernize-make-shared) | ||||
|     list.reset(new hooks_list()); | ||||
|     std::unique_ptr<FunctionPreHook> hook_ptr{ | ||||
|         new CppFunctionTensorPreHook(list, self.output_nr())}; | ||||
|     TORCH_INTERNAL_ASSERT(fn, "Expect grad_fn to be defined for retains_grad"); | ||||
|     fn->add_retains_grad_hook(std::move(hook_ptr)); | ||||
|   } else { | ||||
|   std::shared_ptr<hooks_list>& list = | ||||
|       materialize_autograd_meta(self)->cpp_hooks_list_; | ||||
|   // NOLINTNEXTLINE(modernize-make-shared) | ||||
| @ -250,7 +239,6 @@ void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) { | ||||
|   if (fn) { | ||||
|     fn->add_tensor_pre_hook(std::move(hook_ptr)); | ||||
|   } | ||||
|   } | ||||
| } | ||||
|  | ||||
| void set_grad_accumulator( | ||||
| @ -529,24 +517,6 @@ int64_t VariableHooks::_version(const at::TensorBase& self) const { | ||||
|   return self.unsafeGetTensorImpl()->version_counter().current_version(); | ||||
| } | ||||
|  | ||||
| unsigned register_retains_grad_hook( | ||||
|     const at::TensorBase& self, | ||||
|     std::function<at::TensorBase(const at::TensorBase&)> hook) { | ||||
|   TORCH_CHECK( | ||||
|       self.requires_grad(), | ||||
|       "cannot retain grad on a variable that " | ||||
|       "doesn't require gradient"); | ||||
|   // NB: materialize_autograd_meta unnecessary due to requires grad check | ||||
|   auto& list = | ||||
|       torch::autograd::impl::get_autograd_meta(self)->retains_grad_hooks_list_; | ||||
|   if (!list) { | ||||
|     torch::autograd::impl::create_cpp_hook(self, /*is_retains_grad_hook=*/true); | ||||
|   } | ||||
|   unsigned idx = list->size(); | ||||
|   list->push_back(hook); | ||||
|   return idx; | ||||
| } | ||||
|  | ||||
| void VariableHooks::retain_grad(const at::TensorBase& self) const { | ||||
|   TORCH_CHECK( | ||||
|       self.requires_grad(), | ||||
| @ -583,7 +553,10 @@ void VariableHooks::retain_grad(const at::TensorBase& self) const { | ||||
|     return at::TensorBase{}; | ||||
|   }; | ||||
|  | ||||
|   register_retains_grad_hook(self, retain_grad_hook); | ||||
|   const auto& fn = self.grad_fn(); | ||||
|   std::unique_ptr<FunctionPreHook> hook_ptr{new CppFunctionSingleTensorPreHook( | ||||
|       std::move(retain_grad_hook), self.output_nr())}; | ||||
|   fn->add_retains_grad_hook(std::move(hook_ptr), self.output_nr()); | ||||
|   impl::get_autograd_meta(self)->retains_grad_ = true; | ||||
| } | ||||
|  | ||||
| @ -674,6 +647,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn( | ||||
|       return diff_view_meta->grad_fn_; | ||||
|     } | ||||
|     auto current_version = self._version(); | ||||
|     auto old_fn = diff_view_meta->grad_fn_; | ||||
|     if (diff_view_meta->get_attr_version() != current_version) { | ||||
|       // This is an indirect rebase_history due to another view or the base | ||||
|       // being modified inplace | ||||
| @ -735,7 +709,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn( | ||||
|       diff_view_meta->set_attr_version(current_version); | ||||
|  | ||||
|       torch::autograd::impl::update_tensor_hooks_on_new_gradfn( | ||||
|           self, diff_view_meta->grad_fn_); | ||||
|           self, old_fn, diff_view_meta->grad_fn_); | ||||
|     } | ||||
|     return diff_view_meta->grad_fn_; | ||||
|   } | ||||
|  | ||||
| @ -229,7 +229,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { | ||||
|   // each other, so using both is not defined behavior. | ||||
|   std::vector<std::unique_ptr<FunctionPreHook>> hooks_; | ||||
|   std::shared_ptr<hooks_list> cpp_hooks_list_; | ||||
|   std::shared_ptr<hooks_list> retains_grad_hooks_list_; | ||||
|  | ||||
|   // Only meaningful on leaf variables (must be false otherwise) | ||||
|   bool requires_grad_; | ||||
|  | ||||
		Reference in New Issue
	
	Block a user