mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
functional compiled autograd (#144707)
This PR squashes together the following commits: https://github.com/pytorch/pytorch/pull/144115 https://github.com/pytorch/pytorch/pull/143417 https://github.com/pytorch/pytorch/pull/143405 https://github.com/pytorch/pytorch/pull/143387 https://github.com/pytorch/pytorch/pull/143304 https://github.com/pytorch/pytorch/pull/143296 This is a refactor of compiled autograd to use "functional autograd". The end goal is that it gets compiled autograd's initial capture to stop specializing on Tensor metadata, therefore allowing compiled autograd to better handle Tensor subclasses. For more information, please read the commit messages for each PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144707 Approved by: https://github.com/bdhirsh, https://github.com/xmfan, https://github.com/jansel
This commit is contained in:
@ -34,8 +34,12 @@ using tensor_list = std::vector<at::Tensor>;
|
||||
using variable_list = std::vector<Variable>;
|
||||
using edge_list = std::vector<Edge>;
|
||||
using saved_variable_list = std::vector<SavedVariable>;
|
||||
using ivalue_list = std::vector<c10::IValue>;
|
||||
using functional_apply_t = std::function<
|
||||
variable_list(const variable_list&, const std::vector<c10::IValue>&)>;
|
||||
using IndexRange = std::pair<size_t, size_t>;
|
||||
using torch::dynamo::autograd::CompiledNodeArgs;
|
||||
using torch::dynamo::autograd::PackedArgs;
|
||||
using torch::dynamo::autograd::SwapSavedVariables;
|
||||
|
||||
// Custom deleter to prevent stack overflows.
|
||||
@ -604,6 +608,12 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
||||
std::string("apply_with_saved not implemented: ") + name());
|
||||
}
|
||||
|
||||
// If this node is the AOTBackward node produced by torch.compile.
|
||||
// Compiled Autograd special-cases on this information.
|
||||
virtual bool is_aot_backward() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Performs the `Node`'s actual operation.
|
||||
virtual variable_list apply(variable_list&& inputs) = 0;
|
||||
|
Reference in New Issue
Block a user