mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* Improve Variable interface * Address comments from @apaszke and @colesbury * string ::operator= is not noexcept * Remove ir.h from tracer_state.h to improve build times * Make Variable a struct and pack SavedVariable fields * Implement as_variable_ref * grad_fn_ptr() -> grad_fn_unsafe() * Reduce hackiness of set_type hack * Include variable.h and edge.h in tracer_state.h because it uses them * class Variable -> struct Variable because Windows cant even * Make Variable::output_nr uint32_t instead of int * Add comment about tracing state * Replaced more static_cast<Variable&> and improve docs * Remove SavedVariable destructor and construct members in init list * Clarify docs for Variable * Variable::set_version -> set_version_counter
23 lines
505 B
C++
23 lines
505 B
C++
#pragma once
|
|
|
|
#include <vector>
|
|
|
|
// Hook interfaces that are called on gradients
|
|
|
|
namespace torch { namespace autograd {
|
|
|
|
struct Variable;
|
|
using variable_list = std::vector<Variable>;
|
|
|
|
struct FunctionPreHook {
|
|
virtual ~FunctionPreHook() {}
|
|
virtual variable_list operator()(const variable_list& grads) = 0;
|
|
};
|
|
|
|
struct FunctionPostHook {
|
|
virtual ~FunctionPostHook() {}
|
|
virtual variable_list operator()(const variable_list& grad_input, const variable_list& grad_output) = 0;
|
|
};
|
|
|
|
}} // namespace torch::autograd
|