mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This reverts #147804 due to internal revert. --- This diff reverts D70262395 Reviewed By: RossMcKenzie Differential Revision: D70318024 @diff-train-skip-merge Pull Request resolved: https://github.com/pytorch/pytorch/pull/148164 Approved by: https://github.com/xmfan
56 lines
2.0 KiB
C++
56 lines
2.0 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/autograd/function_hook.h>
|
|
#include <torch/csrc/python_headers.h>
|
|
#include <torch/csrc/utils/object_ptr.h>
|
|
|
|
namespace torch::dynamo::autograd {
|
|
class SwapSavedVariables;
|
|
} // namespace torch::dynamo::autograd
|
|
|
|
namespace torch::autograd {
|
|
|
|
struct PyFunctionTensorPreHook : public FunctionPreHook {
|
|
PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
|
|
~PyFunctionTensorPreHook() override;
|
|
variable_list operator()(const variable_list& values) override;
|
|
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
|
|
PyObject* dict;
|
|
size_t value_idx;
|
|
};
|
|
|
|
struct PyFunctionPreHook : public FunctionPreHook {
|
|
PyFunctionPreHook(PyObject* dict);
|
|
~PyFunctionPreHook() override;
|
|
variable_list operator()(const variable_list& values) override;
|
|
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
|
|
PyObject* dict;
|
|
};
|
|
|
|
struct PyFunctionPostHook : public FunctionPostHook {
|
|
PyFunctionPostHook(PyObject* dict);
|
|
~PyFunctionPostHook() override;
|
|
variable_list operator()(
|
|
const variable_list& outputs,
|
|
const variable_list& inputs) override;
|
|
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
|
|
PyObject* dict;
|
|
};
|
|
|
|
// PyFunctionTensorPostAccGradHooks is a dictionary of PostAccumulateGradHooks,
|
|
// and it is understandable if you are confused by why it's a subclass. We are
|
|
// simply following the precedent of PyFunctionPreHook and PyFunctionPostHook
|
|
// above to easily enroll into existing infrastructure.
|
|
struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
|
|
PyFunctionTensorPostAccGradHooks(PyObject* dict);
|
|
~PyFunctionTensorPostAccGradHooks() override;
|
|
void operator()(const Variable& tensor) override;
|
|
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
|
|
void apply_with_saved(
|
|
Variable& tensor,
|
|
torch::dynamo::autograd::SwapSavedVariables& saved) override;
|
|
PyObject* dict;
|
|
};
|
|
|
|
} // namespace torch::autograd
|