## Before

Previously, compiled autograd (CA) would always unpack all saved variables stored in the autograd graph before executing it. This meant that unpack hooks could not be captured as part of the CA graph, and they would fire out of order with respect to other backward hooks. For memory-saving APIs built on top of saved tensor hooks, such as non-reentrant checkpointing and offloading, no savings were possible because all activations were recomputed/loaded and live at the same time, making the hooks a no-op.

## After

We add unpack hooks into the CA graph so that they can be executed progressively. The Python hook and the hook input themselves are wrapped by non-traceable code, so CA polyfills the wrapping as:

```python
# pseudocode
class SavedVariable:
    def unpack(self):
        if self.hook:
            return self.hook(self.packed_data)
        else:
            return self.packed_data

# This approach won't directly work once we add support for forward AD or double backward.
```

When the CA graph is executed directly (without torch.compiling it) under checkpointing/offloading, the memory profile is expected to match that of the eager autograd engine. If an AOT backward is in the autograd graph, the memory profile is expected to be better than with the eager autograd engine, since unpacking of saved activations can now be delayed until the AOT backward executes.

All tests pass when running the CA graph directly; the remaining issues are in Dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
Approved by: https://github.com/jansel
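As a concrete illustration of the workloads this targets, the sketch below uses the public `torch.autograd.graph.saved_tensors_hooks` API to offload saved activations to CPU. The `pack_to_cpu`/`unpack_from_cpu` names and the toy matmuls are made up for illustration (and the device moves are no-ops on a CPU-only run); with this change, compiled autograd can keep the unpack hooks in its graph and fire them as the corresponding backward nodes execute, instead of re-materializing every activation up front.

```python
# Minimal sketch (assumed example, not from the PR): offload saved
# activations via saved-tensor hooks. The pack hook runs at forward time
# when a tensor is saved for backward; the unpack hook runs at backward
# time, which is exactly what compiled autograd can now capture and
# execute progressively.
import torch


def pack_to_cpu(t: torch.Tensor):
    # Remember the original device and move the saved activation to CPU.
    return t.device, t.to("cpu")


def unpack_from_cpu(packed):
    # Re-materialize the activation only when backward needs it.
    device, t = packed
    return t.to(device)


x = torch.randn(128, 512, requires_grad=True)
w1 = torch.randn(512, 512, requires_grad=True)
w2 = torch.randn(512, 512, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    out = torch.relu(x @ w1) @ w2  # tensors saved here go through pack_to_cpu

# To run the backward under compiled autograd (the subject of this PR), it
# would typically be wrapped via torch._dynamo's compiled autograd machinery;
# that entry point is internal and intentionally omitted here.
out.sum().backward()
```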
#pragma once

#include <ATen/ATen.h>
#include <c10/core/SafePyObject.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

namespace py = pybind11;

namespace torch::autograd {

// SavedVariableHooks implementation backed by a pair of Python callables
// (the pack/unpack hooks registered via torch.autograd.graph.saved_tensors_hooks).
struct PySavedVariableHooks : public SavedVariableHooks {
  PySavedVariableHooks(py::function& pack_hook, py::function& unpack_hook);
  // Called when a tensor is saved for backward; stores the packed result.
  void call_pack_hook(const at::Tensor& tensor) override;
  // Called when a backward node needs the value; applies the unpack hook.
  at::Tensor call_unpack_hook() override;
  ~PySavedVariableHooks() override;
  // Exposes the unpack hook and the packed data so compiled autograd can
  // record the unpacking in its graph instead of performing it eagerly.
  std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
  retrieve_unpack_hook_data() const override;

 private:
  PyObject* pack_hook_;
  PyObject* unpack_hook_;
  PyObject* data_ = nullptr; // result of the pack hook, consumed at unpack time
};

// Manages the stack of default hooks backing the Python
// saved_tensors_hooks context manager (pushed on enter, popped on exit).
struct PyDefaultSavedVariableHooks {
  static void push_hooks(py::function& pack_hook, py::function& unpack_hook);
  static void pop_hooks();
  static std::unique_ptr<SavedVariableHooks> get_hooks();
};

} // namespace torch::autograd
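On the Python side, these declarations back the `torch.autograd.graph.saved_tensors_hooks` context manager: entering it pushes a pack/unpack pair of default hooks and exiting pops it (`PyDefaultSavedVariableHooks::push_hooks` / `pop_hooks`), while each tensor saved for backward inside the context gets a `PySavedVariableHooks` whose pack hook runs at save time and whose unpack hook runs only when a backward node needs the value. The sketch below (an assumed example, not from the PR) just makes that ordering visible; both unpacks happening after the forward pass is what allows compiled autograd to record the unpack calls in its graph and run them progressively.

```python
# Sketch: pack hooks fire during forward (at save time); unpack hooks fire
# during backward, only when a backward node actually needs the saved value.
import torch


def pack(t):
    print("pack  ", tuple(t.shape))
    return t


def unpack(t):
    print("unpack", tuple(t.shape))
    return t


a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = (a * b).sum()  # mul saves both inputs -> "pack" prints twice here

print("forward done")
y.backward()           # "unpack" prints here, when the mul backward runs
```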