mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149987 Approved by: https://github.com/jansel ghstack dependencies: #149647, #149709, #149651, #149897
33 lines
959 B
C++
33 lines
959 B
C++
#pragma once
|
|
#include <torch/csrc/autograd/function_hook.h>

#include <cstddef>
#include <functional>
#include <memory>
#include <vector>
|
|
|
|
namespace torch::autograd {
|
|
|
|
using hooks_list =
|
|
std::vector<std::function<at::TensorBase(const at::TensorBase&)>>;
|
|
|
|
struct CppFunctionTensorPreHook : public FunctionPreHook {
|
|
CppFunctionTensorPreHook(std::shared_ptr<hooks_list> hooks, size_t value_idx);
|
|
variable_list operator()(const variable_list& values) override;
|
|
|
|
std::shared_ptr<hooks_list> hooks_;
|
|
size_t value_idx_;
|
|
};
|
|
|
|
struct CppFunctionSingleTensorPreHook : public FunctionPreHook {
|
|
CppFunctionSingleTensorPreHook(
|
|
std::function<at::TensorBase(const at::TensorBase&)> hook,
|
|
size_t value_idx);
|
|
variable_list operator()(const variable_list& values) override;
|
|
|
|
void compiled_args(
|
|
torch::dynamo::autograd::CompiledNodeArgs& args) const override;
|
|
|
|
std::function<at::TensorBase(const at::TensorBase&)> hook_;
|
|
size_t value_idx_;
|
|
};
|
|
|
|
} // namespace torch::autograd
|