mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 20:34:54 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			66 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			66 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| #pragma once
 | |
| 
 | |
#include <ATen/Tensor.h>
#include <torch/csrc/Export.h>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>
 | |
| 
 | |
// Forward declarations of types from torch::dynamo::autograd. Declaring them
// here lets this header name them in reference parameters without needing
// their full definitions.
namespace torch::dynamo::autograd {
struct SavedState;
class CompiledNodeArgs;
class SwapSavedVariables;
} // namespace torch::dynamo::autograd
 | |
| 
 | |
| // A hook that's called on gradients
 | |
| 
 | |
| namespace torch::autograd {
 | |
| 
 | |
| using Variable = at::Tensor;
 | |
| using variable_list = std::vector<Variable>;
 | |
| 
 | |
| struct TORCH_API FunctionPreHook {
 | |
|   virtual ~FunctionPreHook() = default;
 | |
|   virtual variable_list operator()(const variable_list& grads) = 0;
 | |
|   // only implemented for python hooks, registers hook with compiled autograd
 | |
|   virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
 | |
|     throw std::runtime_error(
 | |
|         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
 | |
|         typeid(*this).name());
 | |
|   }
 | |
| };
 | |
| 
 | |
| struct TORCH_API FunctionPostHook {
 | |
|   virtual ~FunctionPostHook() = default;
 | |
|   virtual variable_list operator()(
 | |
|       const variable_list& outputs /* grad_inputs */,
 | |
|       const variable_list& inputs /* grad_outputs */) = 0;
 | |
|   // only implemented for python hooks, registers hook with compiled autograd
 | |
|   virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
 | |
|     throw std::runtime_error(
 | |
|         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
 | |
|         typeid(*this).name());
 | |
|   }
 | |
| };
 | |
| 
 | |
| struct TORCH_API PostAccumulateGradHook {
 | |
|   virtual ~PostAccumulateGradHook() = default;
 | |
|   virtual void operator()(const Variable& tensor) = 0;
 | |
|   // only implemented for python hooks on nodes, registers hook with compiled
 | |
|   // autograd
 | |
|   virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
 | |
|     throw std::runtime_error(
 | |
|         std::string("not yet implemented for compiled autograd: ") +
 | |
|         typeid(*this).name());
 | |
|   }
 | |
| 
 | |
|   virtual void apply_with_saved(
 | |
|       Variable&,
 | |
|       torch::dynamo::autograd::SwapSavedVariables&) {
 | |
|     throw std::runtime_error(
 | |
|         std::string("not yet implemented for compiled autograd: ") +
 | |
|         typeid(*this).name());
 | |
|   }
 | |
| };
 | |
| 
 | |
| } // namespace torch::autograd
 |