From 457ff9b7aef78e3fd39965586c45ca003423bdca Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Mon, 10 Mar 2025 11:31:32 -0700 Subject: [PATCH] [reland][ca] side-effect free inital trace: compiled_args (#148376) This reverts commit ea12fc8a9ff7da808e0b661ca07e9d4ce75d04bc. Reland https://github.com/pytorch/pytorch/pull/147804, there was a bad import inserted by my linter. Differential Revision: [D70582747](https://our.internmc.facebook.com/intern/diff/D70582747) Pull Request resolved: https://github.com/pytorch/pytorch/pull/148376 Approved by: https://github.com/jansel --- tools/autograd/gen_autograd_functions.py | 4 +-- torch/csrc/autograd/custom_function.h | 2 +- torch/csrc/autograd/function.h | 6 ++-- torch/csrc/autograd/function_hook.h | 9 ++++-- .../autograd/functions/accumulate_grad.cpp | 4 +-- .../csrc/autograd/functions/accumulate_grad.h | 6 ++-- torch/csrc/autograd/functions/basic_ops.cpp | 10 ++++-- torch/csrc/autograd/functions/basic_ops.h | 10 ++++-- torch/csrc/autograd/functions/tensor.cpp | 4 +-- torch/csrc/autograd/functions/tensor.h | 4 +-- torch/csrc/autograd/python_function.cpp | 32 +++++++++---------- torch/csrc/autograd/python_function.h | 13 ++------ torch/csrc/autograd/python_hook.cpp | 8 ++--- torch/csrc/autograd/python_hook.h | 12 ++++--- torch/csrc/autograd/utils/lambda_post_hook.h | 2 +- torch/csrc/distributed/c10d/reducer.cpp | 28 ++++++++-------- torch/csrc/dynamo/compiled_autograd.h | 29 +++++++++++++---- 17 files changed, 104 insertions(+), 79 deletions(-) diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 6d5973aeed15..ea275b58f0f6 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -67,7 +67,7 @@ struct TORCH_API ${op} : public ${superclass} { ${release_variables} } ${will_release_variables} - void compiled_args(CompiledNodeArgs& args) override; + void compiled_args(CompiledNodeArgs& args) const override; variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override; ${saved_variables} ${saved_list_sizes} @@ -127,7 +127,7 @@ variable_list ${op}::apply(variable_list&& grads) { return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args}); } -void ${op}::compiled_args(CompiledNodeArgs& args) { +void ${op}::compiled_args(CompiledNodeArgs& args) const { ${compiled_args} } variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) { diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index 51d082595d8c..25e88cbf6cfe 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -284,7 +284,7 @@ struct CppNode : public Node { void set_ctx_grad_fn(const std::shared_ptr& node); void save_variables_to_ctx(); - void compiled_args(CompiledNodeArgs& args) override { + void compiled_args(CompiledNodeArgs& args) const override { // although neither of the 2 methods below have uniqueness guarantees // it is unlikely for them to collide at the same time args.collect(static_cast(typeid(T).hash_code())); diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index abd11303eafe..106ff5ee0f2f 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -545,8 +545,8 @@ struct TORCH_API Node : std::enable_shared_from_this { return tensor_pre_hooks_; } - virtual std::unique_ptr& - tensor_post_acc_grad_hooks() noexcept { + virtual std::unique_ptr& 
tensor_post_acc_grad_hooks() + const noexcept { static std::unique_ptr empty = nullptr; return empty; } @@ -593,7 +593,7 @@ struct TORCH_API Node : std::enable_shared_from_this { // 2) Collect node information for specialization and caching // Implementations in subclasses should call args.collect() with all node // attrs. These functions are only called durring backward. - virtual void compiled_args(CompiledNodeArgs& args) { + virtual void compiled_args(CompiledNodeArgs& args) const { throw std::runtime_error( std::string("compiled_args not implemented: ") + name()); } diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h index 4e8bba79a169..08d0b8d4c4cc 100644 --- a/torch/csrc/autograd/function_hook.h +++ b/torch/csrc/autograd/function_hook.h @@ -22,7 +22,8 @@ struct TORCH_API FunctionPreHook { virtual ~FunctionPreHook() = default; virtual variable_list operator()(const variable_list& grads) = 0; // only implemented for python hooks, registers hook with compiled autograd - virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) { + virtual void compiled_args( + torch::dynamo::autograd::CompiledNodeArgs& args) const { throw std::runtime_error( std::string("compiled_args nyi, see [Note: Compiled Autograd] ") + typeid(*this).name()); @@ -35,7 +36,8 @@ struct TORCH_API FunctionPostHook { const variable_list& outputs /* grad_inputs */, const variable_list& inputs /* grad_outputs */) = 0; // only implemented for python hooks, registers hook with compiled autograd - virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) { + virtual void compiled_args( + torch::dynamo::autograd::CompiledNodeArgs& args) const { throw std::runtime_error( std::string("compiled_args nyi, see [Note: Compiled Autograd] ") + typeid(*this).name()); @@ -47,7 +49,8 @@ struct TORCH_API PostAccumulateGradHook { virtual void operator()(const Variable& tensor) = 0; // only implemented for python hooks on nodes, registers hook with compiled // autograd - virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) { + virtual void compiled_args( + torch::dynamo::autograd::CompiledNodeArgs& args) const { throw std::runtime_error( std::string("not yet implemented for compiled autograd: ") + typeid(*this).name()); diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp index 4f4ee20efc67..3df791821556 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.cpp +++ b/torch/csrc/autograd/functions/accumulate_grad.cpp @@ -66,12 +66,12 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list { return variable_list(); } -void AccumulateGrad::compiled_args(CompiledNodeArgs& args) { +void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const { if (args.cond(variable.defined() && variable.requires_grad())) { args.collect(variable); args.collect(variable.grad()); } - auto& hook = tensor_post_acc_grad_hooks(); + const auto& hook = tensor_post_acc_grad_hooks(); if (hook != nullptr) { hook->compiled_args(args); } diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h index 39ea91bf0e76..b1768ee2a93c 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.h +++ b/torch/csrc/autograd/functions/accumulate_grad.h @@ -50,8 +50,8 @@ struct TORCH_API AccumulateGrad : public Node { return impl::hooks(variable); } - std::unique_ptr& tensor_post_acc_grad_hooks() noexcept - override { + std::unique_ptr& 
tensor_post_acc_grad_hooks() + const noexcept override { // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor, // it can be destroyed even though the Tensor is still alive (contrary // to all other Nodes). So we must lazily read the Tensor hooks here. @@ -262,7 +262,7 @@ struct TORCH_API AccumulateGrad : public Node { } } - void compiled_args(CompiledNodeArgs& args) override; + void compiled_args(CompiledNodeArgs& args) const override; variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) override; diff --git a/torch/csrc/autograd/functions/basic_ops.cpp b/torch/csrc/autograd/functions/basic_ops.cpp index 2b17307925d8..a310be58e288 100644 --- a/torch/csrc/autograd/functions/basic_ops.cpp +++ b/torch/csrc/autograd/functions/basic_ops.cpp @@ -12,11 +12,15 @@ namespace torch::autograd { -auto Error::apply(variable_list&& inputs) -> variable_list { +variable_list Error::apply(variable_list&& inputs) { + return static_cast(this)->apply(std::move(inputs)); +} + +variable_list Error::apply(variable_list&& inputs) const { throw std::runtime_error(msg); } -void Error::compiled_args(CompiledNodeArgs& args) { +void Error::compiled_args(CompiledNodeArgs& args) const { // throw the error durring collect, the graph won't get compiled apply(variable_list()); } @@ -66,7 +70,7 @@ auto Identity::apply(variable_list&& grads) -> variable_list { return std::move(grads); } -void GraphRoot::compiled_args(CompiledNodeArgs& args) { +void GraphRoot::compiled_args(CompiledNodeArgs& args) const { args.collect(outputs); } variable_list GraphRoot::apply_with_saved( diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h index d9e11b1f45fc..a00b7eab9068 100644 --- a/torch/csrc/autograd/functions/basic_ops.h +++ b/torch/csrc/autograd/functions/basic_ops.h @@ -18,8 +18,9 @@ struct TORCH_API Error : public Node { Error(std::string msg) : msg(std::move(msg)) {} variable_list apply(variable_list&& inputs) override; + variable_list apply(variable_list&& inputs) const; - void compiled_args(CompiledNodeArgs& args) override; + void compiled_args(CompiledNodeArgs& args) const override; variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) override; @@ -51,6 +52,7 @@ struct TORCH_API DelayedError : public Node { } variable_list apply(variable_list&& inputs) override; + variable_list apply(variable_list&& inputs) const; std::string msg; }; @@ -61,6 +63,7 @@ struct TORCH_API UndefinedGrad : public Node { } variable_list apply(variable_list&& inputs) override; + variable_list apply(variable_list&& inputs) const; }; struct TORCH_API UndefinedGradBackward : public Node { @@ -69,8 +72,9 @@ struct TORCH_API UndefinedGradBackward : public Node { UndefinedGradBackward() = default; variable_list apply(variable_list&& inputs) override; + variable_list apply(variable_list&& inputs) const; - void compiled_args(CompiledNodeArgs& args) override {} + void compiled_args(CompiledNodeArgs& args) const override {} variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) override { @@ -93,7 +97,7 @@ struct TORCH_API GraphRoot : public Node { return outputs; } - void compiled_args(CompiledNodeArgs& args) override; + void compiled_args(CompiledNodeArgs& args) const override; variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) override; diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp index 
a06ebefb85c3..5f035c6f3320 100644 --- a/torch/csrc/autograd/functions/tensor.cpp +++ b/torch/csrc/autograd/functions/tensor.cpp @@ -60,7 +60,7 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list { src_options); } -void CopyBackwards::compiled_args(CompiledNodeArgs& args) { +void CopyBackwards::compiled_args(CompiledNodeArgs& args) const { args.collect(src_options); } @@ -235,7 +235,7 @@ void CopySlices::release_variables() { fn = nullptr; } -void CopySlices::compiled_args(CompiledNodeArgs& args) { +void CopySlices::compiled_args(CompiledNodeArgs& args) const { TORCH_CHECK(!view_fn, "view_fn not supported by compiled autograd") TORCH_INTERNAL_ASSERT((bool)fn); args.collect(base); diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h index 4b0c2190ed54..78a8819ad5f2 100644 --- a/torch/csrc/autograd/functions/tensor.h +++ b/torch/csrc/autograd/functions/tensor.h @@ -15,7 +15,7 @@ namespace torch::autograd { struct TORCH_API CopyBackwards : public Node { variable_list apply(variable_list&& grads) override; - void compiled_args(CompiledNodeArgs& args) override; + void compiled_args(CompiledNodeArgs& args) const override; variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) override; @@ -168,7 +168,7 @@ struct TORCH_API CopySlices : public Node { variable_list apply(variable_list&& inputs) override; void release_variables() override; - void compiled_args(CompiledNodeArgs& args) override; + void compiled_args(CompiledNodeArgs& args) const override; variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) override; diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 4029e728cb05..978cf5c43f3a 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -185,9 +185,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list { return to_variable_list(r.get(), is_variable_input); } -auto PyNode::defer_to_dynamo( +auto PyNode::apply_with_saved_impl( const variable_list& inputs, - const std::optional& compiler) -> variable_list { + const SwapSavedVariables& saved) -> variable_list { pybind11::gil_scoped_acquire gil; at::OptionalDeviceGuard _device_guard; THPFunction* py_fn = (THPFunction*)obj; @@ -235,24 +235,24 @@ auto PyNode::defer_to_dynamo( } THPObjectPtr saved_tensors(unpack_saved_variables( py_fn, [](const Variable& var) { return THPVariable_Wrap(var); })); - TORCH_INTERNAL_ASSERT( - _backward_idx.has_value(), - "indices should already be set by compiled_args, called before apply_with_saved"); + + auto [bwd_idx, maybe_bwd_state_idx] = saved.retrieve_pynode_objs(this); + PyObject* backward_state_idx = Py_None; - if (_backward_state_idx.has_value()) { - backward_state_idx = THPUtils_packInt64(_backward_state_idx.value()); + if (maybe_bwd_state_idx.has_value()) { + backward_state_idx = THPUtils_packUInt64(maybe_bwd_state_idx.value()); // this might be simplifiable now that we no longer inline Py_CLEAR(py_fn->compiled_autograd_backward_state); } THPObjectPtr r(PyObject_CallMethod( // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - compiler.value(), + saved.get_py_compiler(), "proxy_call_backward", "OOOiOO", pyInputs.get(), fwdInputMetadatas.get(), saved_tensors.get(), - *_backward_idx, + bwd_idx, obj, backward_state_idx)); @@ -301,7 +301,7 @@ bool PyNode::is_aot_backward() const { return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id"); } -void 
PyNode::compiled_args(CompiledNodeArgs& args) { +void PyNode::compiled_args(CompiledNodeArgs& args) const { static PyObject* method_name = PyUnicode_InternFromString("_compiled_autograd_key"); THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr)); @@ -346,14 +346,15 @@ void PyNode::compiled_args(CompiledNodeArgs& args) { args.collect(f->input_info); Py_INCREF(obj); - _backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter())); - + c10::SafePyObject backward_obj(obj, getPyInterpreter()); + std::optional backward_state_obj; PyObject* bw_state = f->compiled_autograd_backward_state; if (args.cond(bw_state != nullptr)) { Py_INCREF(bw_state); - _backward_state_idx = args.add_backward_state( - c10::SafePyObject(bw_state, getPyInterpreter())); + backward_state_obj = c10::SafePyObject(bw_state, getPyInterpreter()); } + args.collect_pynode_objs( + this, std::move(backward_obj), std::move(backward_state_obj)); } variable_list PyNode::apply_with_saved( @@ -366,8 +367,7 @@ variable_list PyNode::apply_with_saved( saved.before(f->materialize_non_diff_grads); saved.before(f->output_info); saved.before(f->input_info); - variable_list result = - defer_to_dynamo(variable_list(inputs), saved.get_py_compiler()); + variable_list result = apply_with_saved_impl(variable_list(inputs), saved); saved.after(f->compiled_autograd_symints); saved.after(f->saved_variables); saved.after(f->needs_input_grad); diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 671bc9758270..e24399c10aa3 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -36,9 +36,9 @@ struct PyNode : public Node { const std::vector& is_variable_input); variable_list apply(variable_list&& inputs) override; - variable_list defer_to_dynamo( + variable_list apply_with_saved_impl( const variable_list& inputs, - const std::optional& compiler); + const SwapSavedVariables& saved); void release_variables() override; std::string name() const override; @@ -46,7 +46,7 @@ struct PyNode : public Node { bool is_aot_backward() const override; - void compiled_args(CompiledNodeArgs& args) override; + void compiled_args(CompiledNodeArgs& args) const override; variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) override; @@ -54,13 +54,6 @@ struct PyNode : public Node { // THPFunction this Function is wrapping. Owning! 
PyObject* obj; - // The AutogradCompilerCall::hooks idx corresponding to this node's backward - std::optional _backward_idx; - - // The AutogradCompilerCall::hooks idx corresponding to this node's - // backward_state - std::optional _backward_state_idx; - // NOLINTNEXTLINE(bugprone-exception-escape) ~PyNode() override { // Can't use THPObjectPtr as a field in this class; destructor won't take diff --git a/torch/csrc/autograd/python_hook.cpp b/torch/csrc/autograd/python_hook.cpp index 2ba031ceb36f..3b2be3cb3f38 100644 --- a/torch/csrc/autograd/python_hook.cpp +++ b/torch/csrc/autograd/python_hook.cpp @@ -176,7 +176,7 @@ auto PyFunctionPostHook::operator()( return unwrap_variables(PyTuple_GetItem(tup.get(), 0)); } -void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) { +void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) const { PyObject *key = nullptr, *value = nullptr; Py_ssize_t pos = 0; Py_BEGIN_CRITICAL_SECTION(dict); @@ -189,7 +189,7 @@ void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) { Py_END_CRITICAL_SECTION(); } -void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) { +void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) const { PyObject *key = nullptr, *value = nullptr; Py_ssize_t pos = 0; Py_BEGIN_CRITICAL_SECTION(dict); @@ -200,7 +200,7 @@ void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) { Py_END_CRITICAL_SECTION(); } -void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) { +void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) const { PyObject *key = nullptr, *value = nullptr; Py_ssize_t pos = 0; Py_BEGIN_CRITICAL_SECTION(dict); @@ -237,7 +237,7 @@ auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor) } void PyFunctionTensorPostAccGradHooks::compiled_args( - torch::dynamo::autograd::CompiledNodeArgs& args) { + torch::dynamo::autograd::CompiledNodeArgs& args) const { PyObject *key = nullptr, *value = nullptr; Py_ssize_t pos = 0; Py_BEGIN_CRITICAL_SECTION(dict); diff --git a/torch/csrc/autograd/python_hook.h b/torch/csrc/autograd/python_hook.h index a17a97924b2a..9b744509960d 100644 --- a/torch/csrc/autograd/python_hook.h +++ b/torch/csrc/autograd/python_hook.h @@ -14,7 +14,8 @@ struct PyFunctionTensorPreHook : public FunctionPreHook { PyFunctionTensorPreHook(PyObject* dict, size_t value_idx); ~PyFunctionTensorPreHook() override; variable_list operator()(const variable_list& values) override; - void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override; + void compiled_args( + torch::dynamo::autograd::CompiledNodeArgs& args) const override; PyObject* dict; size_t value_idx; }; @@ -23,7 +24,8 @@ struct PyFunctionPreHook : public FunctionPreHook { PyFunctionPreHook(PyObject* dict); ~PyFunctionPreHook() override; variable_list operator()(const variable_list& values) override; - void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override; + void compiled_args( + torch::dynamo::autograd::CompiledNodeArgs& args) const override; PyObject* dict; }; @@ -33,7 +35,8 @@ struct PyFunctionPostHook : public FunctionPostHook { variable_list operator()( const variable_list& outputs, const variable_list& inputs) override; - void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override; + void compiled_args( + torch::dynamo::autograd::CompiledNodeArgs& args) const override; PyObject* dict; }; @@ -45,7 +48,8 @@ struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook { PyFunctionTensorPostAccGradHooks(PyObject* dict); 
~PyFunctionTensorPostAccGradHooks() override; void operator()(const Variable& tensor) override; - void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override; + void compiled_args( + torch::dynamo::autograd::CompiledNodeArgs& args) const override; void apply_with_saved( Variable& tensor, torch::dynamo::autograd::SwapSavedVariables& saved) override; diff --git a/torch/csrc/autograd/utils/lambda_post_hook.h b/torch/csrc/autograd/utils/lambda_post_hook.h index c2f47347a4cf..a98fab04afb9 100644 --- a/torch/csrc/autograd/utils/lambda_post_hook.h +++ b/torch/csrc/autograd/utils/lambda_post_hook.h @@ -27,7 +27,7 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook { return fn_(outputs, inputs); } - void compiled_args(CompiledNodeArgs& args) override {} + void compiled_args(CompiledNodeArgs& args) const override {} protected: std::function fn_; diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 7e85d55543a6..0cec78443ea3 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -183,21 +183,23 @@ Reducer::Reducer( #endif // Hook to execute after the gradient accumulator has executed. hooks_.emplace_back( - grad_accumulator->add_post_hook( - std::make_unique( - [this, variable_index]( - const torch::autograd::variable_list& outputs, - const torch::autograd::variable_list& /* unused */) { + grad_accumulator->add_post_hook(std::make_unique< + torch::autograd::utils:: + LambdaPostHook>( + [this, variable_index]( + const torch::autograd::variable_list& outputs, + const torch::autograd::variable_list& /* unused */) { #ifndef _WIN32 - this->rpc_context_.set( - ThreadLocalDistAutogradContext::getContextPtr()); + this->rpc_context_.set( + ThreadLocalDistAutogradContext::getContextPtr()); #endif - this->autograd_hook(variable_index); - return outputs; - }, - [=](torch::autograd::CompiledNodeArgs& args) { - // Make post_hook an noop if compiled_autograds is enabled. - })), + this->autograd_hook(variable_index); + return outputs; + }, + [=](torch::autograd::CompiledNodeArgs& args) { + TORCH_INTERNAL_ASSERT( + "Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\"."); + })), grad_accumulator); // Map raw function pointer to parameter index. 
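
For reference, the reducer hunk above relies on LambdaPostHook's two-callback constructor: the first callback runs during eager backward, the second is invoked when compiled autograd walks the graph and calls compiled_args on each hook. A minimal sketch of that registration pattern, using only headers that appear in this patch; the helper function, its name, and its captures are illustrative and not part of the change:

#include <functional>
#include <memory>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/utils/lambda_post_hook.h>

using torch::autograd::variable_list;

// Illustrative helper: attach a post hook to a grad-accumulator node so that
// eager backward runs `on_grad_ready`, while compiled autograd reaches the
// second callback during its compiled_args collection pass.
void attach_post_hook(
    const std::shared_ptr<torch::autograd::Node>& grad_accumulator,
    std::function<void()> on_grad_ready) {
  grad_accumulator->add_post_hook(
      std::make_unique<torch::autograd::utils::LambdaPostHook>(
          [on_grad_ready](
              const variable_list& outputs,
              const variable_list& /* grad outputs, unused */) {
            on_grad_ready(); // eager-mode path
            return outputs;
          },
          [](torch::autograd::CompiledNodeArgs& /* args */) {
            // compiled-autograd path: collect state via `args`, or signal
            // (as the reducer above now does) that the combination is
            // unsupported.
          }));
}
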
diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index 11a7334360c5..3db220bcecbc 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -349,6 +349,9 @@ struct AutogradCompilerCall { std::vector size_input_origins; std::unordered_map> sv_to_hooks; + // pynode -> backward and backward state idx + std::unordered_map>> + pynode_objs; }; class CompiledNodeArgs { @@ -619,12 +622,17 @@ class CompiledNodeArgs { typeid(*node), _specialization_key, _specialization_key_size); } - size_t add_backward(c10::SafePyObject&& obj) { - return _compiler.emplace_hook(std::move(obj)); - } - - size_t add_backward_state(c10::SafePyObject&& obj) { - return _compiler.emplace_hook(std::move(obj)); + void collect_pynode_objs( + const Node* pynode, + c10::SafePyObject&& bwd, + std::optional&& bwd_state) { + size_t bwd_idx = _compiler.emplace_hook(std::move(bwd)); + std::optional bwd_state_idx; + if (auto state = std::move(bwd_state); state.has_value()) { + bwd_state_idx = _compiler.emplace_hook(std::move(state.value())); + } + _compiler.pynode_objs.emplace( + pynode, std::make_pair(bwd_idx, bwd_state_idx)); } void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) { @@ -743,6 +751,13 @@ class SwapSavedVariables { // cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes, // allows tracing to happen, then swaps them back afterwards. public: + std::pair> retrieve_pynode_objs( + Node* pynode) const { + auto it = compiler.pynode_objs.find(pynode); + TORCH_INTERNAL_ASSERT(it != compiler.pynode_objs.end()); + return it->second; + } + void before(at::Tensor& t) { TensorArg& arg = compiler.tensor_args.lookup(t); stashed_tensors.save(&t, std::move(t)); @@ -948,7 +963,7 @@ class SwapSavedVariables { const NodeCall& n) : compiler(c), state(s), py_compiler(p), curr_node_call(n) {} - PyObject* get_py_compiler() { + PyObject* get_py_compiler() const { return py_compiler; }
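
The compiled_autograd.h hunk above is what lets compiled_args become const everywhere: the per-PyNode bookkeeping that used to be cached on the node as _backward_idx/_backward_state_idx now lives on the AutogradCompilerCall, keyed by node pointer, and is re-read at apply time through SwapSavedVariables::retrieve_pynode_objs. The following dependency-free sketch illustrates that pattern only; all types are simplified stand-ins, only the names mirror the patch, and the index allocation replaces the real emplace_hook calls on c10::SafePyObject handles:

#include <cstdio>
#include <optional>
#include <unordered_map>
#include <utility>

struct Node;

// Simplified stand-in for AutogradCompilerCall: owns the per-node bookkeeping
// so that the collection pass never has to mutate a Node.
struct CompilerCall {
  // node -> (backward idx, optional backward-state idx), mirroring pynode_objs
  std::unordered_map<const Node*, std::pair<size_t, std::optional<size_t>>>
      pynode_objs;
  size_t next_hook_idx = 0;

  void collect_pynode_objs(const Node* node, bool has_backward_state) {
    size_t bwd_idx = next_hook_idx++;
    std::optional<size_t> bwd_state_idx;
    if (has_backward_state) {
      bwd_state_idx = next_hook_idx++;
    }
    pynode_objs.emplace(node, std::make_pair(bwd_idx, bwd_state_idx));
  }

  std::pair<size_t, std::optional<size_t>> retrieve_pynode_objs(
      const Node* node) const {
    return pynode_objs.at(node);
  }
};

struct Node {
  bool has_backward_state = false;

  // The collection pass records into the compiler call, never into the node,
  // which is what allows it to be const (side-effect free for the graph).
  void compiled_args(CompilerCall& call) const {
    call.collect_pynode_objs(this, has_backward_state);
  }

  // The apply pass looks the indices up again instead of reading members
  // cached during collection.
  void apply_with_saved(const CompilerCall& call) const {
    auto [bwd_idx, bwd_state_idx] = call.retrieve_pynode_objs(this);
    std::printf(
        "bwd_idx=%zu has_state=%d\n", bwd_idx, (int)bwd_state_idx.has_value());
  }
};

int main() {
  CompilerCall call;
  Node plain, stateful;
  stateful.has_backward_state = true;
  plain.compiled_args(call);      // collection pass (initial trace)
  stateful.compiled_args(call);
  plain.apply_with_saved(call);   // apply pass
  stateful.apply_with_saved(call);
  return 0;
}

In the patch itself, the indices come from AutogradCompilerCall::emplace_hook on the stashed SafePyObjects, and PyNode::apply_with_saved_impl forwards bwd_idx to the Python compiler's proxy_call_backward call shown earlier in the diff.
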