pytorch/torch/csrc/autograd/function_hook.h
Peter Goldsborough 2d5fbe6e0d Improve Variable interface (#5127)
* Improve Variable interface

* Address comments from @apaszke and @colesbury

* string ::operator= is not noexcept

* Remove ir.h from tracer_state.h to improve build times

* Make Variable a struct and pack SavedVariable fields

* Implement as_variable_ref

* grad_fn_ptr() -> grad_fn_unsafe()

* Reduce hackiness of set_type hack

* Include variable.h and edge.h in tracer_state.h because it uses them

* class Variable -> struct Variable because Windows can't even

* Make Variable::output_nr uint32_t instead of int

* Add comment about tracing state

* Replace more static_cast<Variable&> and improve docs

* Remove SavedVariable destructor and construct members in init list

* Clarify docs for Variable

* Variable::set_version -> set_version_counter
2018-02-12 23:26:26 -05:00

#pragma once

#include <vector>

// A hook that's called on gradients

namespace torch { namespace autograd {

struct Variable;
using variable_list = std::vector<Variable>;

struct FunctionPreHook {
  virtual ~FunctionPreHook() {}
  virtual variable_list operator()(const variable_list& grads) = 0;
};

struct FunctionPostHook {
  virtual ~FunctionPostHook() {}
  virtual variable_list operator()(const variable_list& grad_input, const variable_list& grad_output) = 0;
};

}} // namespace torch::autograd