functional compiled autograd (#144707)

This PR squashes together the following commits:

https://github.com/pytorch/pytorch/pull/144115
https://github.com/pytorch/pytorch/pull/143417
https://github.com/pytorch/pytorch/pull/143405
https://github.com/pytorch/pytorch/pull/143387
https://github.com/pytorch/pytorch/pull/143304
https://github.com/pytorch/pytorch/pull/143296

This is a refactor of compiled autograd to use "functional autograd". The end goal is that it gets compiled autograd's initial capture to stop specializing on Tensor metadata, therefore allowing compiled autograd to better handle Tensor subclasses.

For more information, please read the commit messages for each PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144707
Approved by: https://github.com/bdhirsh, https://github.com/xmfan, https://github.com/jansel
This commit is contained in:
rzou
2025-01-27 05:20:56 +00:00
committed by PyTorch MergeBot
parent 87fdadde1d
commit ea141d8134
28 changed files with 1809 additions and 223 deletions

View File

@ -34,8 +34,12 @@ using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
using edge_list = std::vector<Edge>;
using saved_variable_list = std::vector<SavedVariable>;
using ivalue_list = std::vector<c10::IValue>;
using functional_apply_t = std::function<
variable_list(const variable_list&, const std::vector<c10::IValue>&)>;
using IndexRange = std::pair<size_t, size_t>;
using torch::dynamo::autograd::CompiledNodeArgs;
using torch::dynamo::autograd::PackedArgs;
using torch::dynamo::autograd::SwapSavedVariables;
// Custom deleter to prevent stack overflows.
@ -604,6 +608,12 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
std::string("apply_with_saved not implemented: ") + name());
}
// If this node is the AOTBackward node produced by torch.compile.
// Compiled Autograd special-cases on this information.
virtual bool is_aot_backward() const {
return false;
}
protected:
/// Performs the `Node`'s actual operation.
virtual variable_list apply(variable_list&& inputs) = 0;