Mirror of https://github.com/pytorch/pytorch.git
This reverts commit ea12fc8a9ff7da808e0b661ca07e9d4ce75d04bc, relanding https://github.com/pytorch/pytorch/pull/147804; the original land had a bad import inserted by my linter.

Differential Revision: [D70582747](https://our.internmc.facebook.com/intern/diff/D70582747)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148376
Approved by: https://github.com/jansel
60 lines
2.0 KiB
C++
#pragma once

#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>

namespace torch::dynamo::autograd {
class SwapSavedVariables;
} // namespace torch::dynamo::autograd

namespace torch::autograd {

// Pre-hook wrapping a Python dict of callables registered on a single
// tensor; value_idx selects which entry of `values` the hooks rewrite.
struct PyFunctionTensorPreHook : public FunctionPreHook {
  PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
  ~PyFunctionTensorPreHook() override;
  variable_list operator()(const variable_list& values) override;
  void compiled_args(
      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
  PyObject* dict;
  size_t value_idx;
};
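
// Usage sketch (illustrative, not part of this header): `node` and the
// Python `hook_dict` are assumptions, with the GIL held by the caller. A
// tensor-level pre-hook is attached to the autograd node that produced the
// tensor, e.g.:
//
//   auto hook = std::make_unique<PyFunctionTensorPreHook>(
//       hook_dict, /*value_idx=*/tensor.output_nr());
//   node->add_tensor_pre_hook(std::move(hook));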

// Pre-hook wrapping a Python dict of callables registered on the node
// itself; each callable sees the full tuple of incoming gradients and may
// return a replacement tuple.
struct PyFunctionPreHook : public FunctionPreHook {
  PyFunctionPreHook(PyObject* dict);
  ~PyFunctionPreHook() override;
  variable_list operator()(const variable_list& values) override;
  void compiled_args(
      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
  PyObject* dict;
};
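
// Usage sketch (same assumed `node` and `hook_dict` as above): node-level
// pre-hooks go through Node::add_pre_hook, e.g.:
//
//   node->add_pre_hook(std::make_unique<PyFunctionPreHook>(hook_dict));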

// Post-hook run after the node fires; each callable in the Python dict sees
// the node's computed gradients (`outputs`) and the incoming gradients
// (`inputs`), and may return replacements for `outputs`.
struct PyFunctionPostHook : public FunctionPostHook {
  PyFunctionPostHook(PyObject* dict);
  ~PyFunctionPostHook() override;
  variable_list operator()(
      const variable_list& outputs,
      const variable_list& inputs) override;
  void compiled_args(
      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
  PyObject* dict;
};
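
// Usage sketch (same assumptions as above): post-hooks go through
// Node::add_post_hook, which returns a handle for later removal, e.g.:
//
//   auto handle = node->add_post_hook(
//       std::make_unique<PyFunctionPostHook>(hook_dict));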

// PyFunctionTensorPostAccGradHooks is a dictionary of PostAccumulateGradHooks,
// and it is understandable if you are confused by why it's a subclass. We are
// simply following the precedent of PyFunctionPreHook and PyFunctionPostHook
// above to easily enroll into existing infrastructure.
struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
  PyFunctionTensorPostAccGradHooks(PyObject* dict);
  ~PyFunctionTensorPostAccGradHooks() override;
  void operator()(const Variable& tensor) override;
  void compiled_args(
      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
  void apply_with_saved(
      Variable& tensor,
      torch::dynamo::autograd::SwapSavedVariables& saved) override;
  PyObject* dict;
};
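
// Usage sketch (hypothetical `hook_dict` and `leaf` tensor; the real
// registration path is Python's tensor.register_post_accumulate_grad_hook).
// After AccumulateGrad writes leaf.grad, the engine invokes the container
// once with the leaf tensor:
//
//   PyFunctionTensorPostAccGradHooks hooks(hook_dict);
//   hooks(leaf);  // runs every callable in hook_dict on the leaf tensor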

} // namespace torch::autograd