Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62563

This exposes a pair of functions to Python users: `torch.autograd.graph.set_saved_tensors_default_hooks(pack, unpack)` and `torch.autograd.graph.reset_saved_tensors_default_hooks()`. These functions control the hooks applied to saved tensors: every tensor saved while the default hooks are set is packed with the `pack` function and unpacked with the `unpack` function when it is needed.

Currently, this works by simply calling `register_hooks` (cf. #60975) directly at the end of the `SavedVariable` constructor. This could be optimized further by not performing the copy before registering default hooks, but that would require a small refactor. Edit: the refactor is done in #61927.

A current limitation is that tensors saved while the default hooks are set cannot have additional hooks registered on them.

For instance, to implement something like #28997, one could define a `pack` function that saves the tensor to disk whenever it is too big and returns a filename, while `unpack` simply reads the file back and returns the tensor, e.g.:

```
import os
import uuid

import torch

tmp_dir = "/tmp/saved_tensors"  # illustrative scratch directory; any writable path works

def pack(x):
    # Serialize the saved tensor under a unique name and keep only
    # the filename in the autograd graph.
    name = os.path.join(tmp_dir, str(uuid.uuid4()))
    torch.save(x, name)
    return name

def unpack(name):
    # Read the tensor back from disk when backward needs it.
    return torch.load(name)
```

Relanding previous PR: https://github.com/pytorch/pytorch/pull/61834

The original PR led to a timeout error in: https://www.internalfb.com/mast/job/yuguo-release_canary_offline_training-inlinecvrp_a-canary_offline_train_28a7ecfc

Now passing: https://www.internalfb.com/mast/job/quach-release_canary_offline_training-inlinecvrp_a-canary_offline_train_9bb57e98

The difference in the new version is that we no longer need to acquire the GIL when calling `PyDefaultSavedVariableHooks::get_hooks`.

Test Plan: Imported from OSS

Reviewed By: iramazanli

Differential Revision: D30045405

Pulled By: Varal7

fbshipit-source-id: 7f6c07af3a56fe8835d5edcc815c15ea4fb4e332
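As a quick usage sketch of the API described in the summary above (a minimal example, assuming `set_saved_tensors_default_hooks` and `reset_saved_tensors_default_hooks` are available under `torch.autograd.graph` as described; the identity hooks here are purely illustrative):

```
import torch

# Illustrative no-op hooks: pack returns the tensor unchanged and
# unpack returns it as-is. A disk-backed pair like the one in the
# summary works the same way.
def pack(x):
    return x

def unpack(x):
    return x

torch.autograd.graph.set_saved_tensors_default_hooks(pack, unpack)
a = torch.randn(5, requires_grad=True)
y = (a * a).sum()  # `a` is saved for backward, so pack(a) runs here
torch.autograd.graph.reset_saved_tensors_default_hooks()

y.backward()  # unpack runs when the saved tensor is needed
print(torch.allclose(a.grad, 2 * a))  # True
```

Note that the hooks in effect at the time a tensor is saved are the ones used later during backward; since `register_hooks` runs in the `SavedVariable` constructor, resetting the defaults before calling `backward()` does not affect tensors that were already packed.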
42 lines | 1.2 KiB | C++
// torch/csrc/autograd/python_engine.h

#pragma once

#include <torch/csrc/python_headers.h>

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/engine.h>

bool THPEngine_initModule(PyObject *module);

namespace torch { namespace autograd { namespace python {

struct PythonEngine : public Engine {
  static Engine& get_python_engine();
  ~PythonEngine() override;
  void thread_init(int device,
      const std::shared_ptr<ReadyQueue>& ready_queue,
      bool should_increment) override;
  void thread_on_exception(
      std::shared_ptr<GraphTask> graph_task,
      const std::shared_ptr<Node>& fn,
      std::exception& e) override;
  variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      bool accumulate_grad,
      const edge_list& outputs = {}) override;

  c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task(
      const std::shared_ptr<GraphTask>& graph_task,
      std::shared_ptr<Node> graph_root,
      InputBuffer&& input_buffer) override;

  std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() override;
  // Supplies the default saved-tensor hooks registered from Python
  // (see set_saved_tensors_default_hooks in the summary above).
  std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks() override;

 private:
  PythonEngine();
};

}}} // namespace torch::autograd::python